mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
79 Commits
use-conntr
...
debug-user
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24053750d9 | ||
|
|
d5081cef90 | ||
|
|
488e619ec7 | ||
|
|
d2b42c8f68 | ||
|
|
2f44fe2e23 | ||
|
|
d8dc107bee | ||
|
|
3fa915e271 | ||
|
|
47c3afe561 | ||
|
|
84bfecdd37 | ||
|
|
3cf87b6846 | ||
|
|
4fe4c2054d | ||
|
|
38ada44a0e | ||
|
|
dbf81a145e | ||
|
|
39483f8ca8 | ||
|
|
c0eaea938e | ||
|
|
ef8b8a2891 | ||
|
|
2817f62c13 | ||
|
|
4a9049566a | ||
|
|
85f92f8321 | ||
|
|
714beb6e3b | ||
|
|
400b9fca32 | ||
|
|
4013298e22 | ||
|
|
312bfd9bd7 | ||
|
|
8db05838ca | ||
|
|
c69df13515 | ||
|
|
986eb8c1e0 | ||
|
|
197761ba4d | ||
|
|
f74ea64c7b | ||
|
|
3b7b9d25bc | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c | ||
|
|
7cb366bc7d | ||
|
|
a354004564 | ||
|
|
75bdd47dfb | ||
|
|
b165f63327 | ||
|
|
51bb52cdf5 | ||
|
|
4134b857b4 | ||
|
|
7839d2c169 | ||
|
|
b9f82e2f8a | ||
|
|
fd2a21c65d | ||
|
|
82d982b0ab | ||
|
|
9e24fe7701 | ||
|
|
e470701b80 | ||
|
|
e3ce026355 | ||
|
|
5ea2806663 | ||
|
|
d6b0673580 | ||
|
|
14913cfa7a | ||
|
|
03f600b576 | ||
|
|
192c97aa63 | ||
|
|
4db78db49a | ||
|
|
87e600a4f3 | ||
|
|
6162aeb82d | ||
|
|
1ba1e092ce | ||
|
|
86dbb4ee4f | ||
|
|
4af177215f | ||
|
|
df9c1b9883 | ||
|
|
5752bb78f2 | ||
|
|
fbd783ad58 | ||
|
|
80702b9323 | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 | ||
|
|
21464ac770 | ||
|
|
ed5647028a | ||
|
|
29a6e5be71 | ||
|
|
6124e3b937 | ||
|
|
50f5cc48cd | ||
|
|
101cce27f2 | ||
|
|
a4f04f5570 | ||
|
|
fceb3ca392 | ||
|
|
34d86c5ab8 |
27
.git-branches.toml
Normal file
27
.git-branches.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
# More info around this file at https://www.git-town.com/configuration-file
|
||||
|
||||
[branches]
|
||||
main = "main"
|
||||
perennials = []
|
||||
perennial-regex = ""
|
||||
|
||||
[create]
|
||||
new-branch-type = "feature"
|
||||
push-new-branches = false
|
||||
|
||||
[hosting]
|
||||
dev-remote = "origin"
|
||||
# platform = ""
|
||||
# origin-hostname = ""
|
||||
|
||||
[ship]
|
||||
delete-tracking-branch = false
|
||||
strategy = "squash-merge"
|
||||
|
||||
[sync]
|
||||
feature-strategy = "merge"
|
||||
perennial-strategy = "rebase"
|
||||
prototype-strategy = "merge"
|
||||
push-hook = true
|
||||
tags = true
|
||||
upstream = false
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -2,6 +2,10 @@
|
||||
|
||||
## Issue ticket number and link
|
||||
|
||||
## Stack
|
||||
|
||||
<!-- branch-stack -->
|
||||
|
||||
### Checklist
|
||||
- [ ] Is it a bug fix
|
||||
- [ ] Is a typo/documentation fix
|
||||
|
||||
21
.github/workflows/git-town.yml
vendored
Normal file
21
.github/workflows/git-town.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Git Town
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
name: Display the branch stack
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
10
.github/workflows/golang-test-freebsd.yml
vendored
10
.github/workflows/golang-test-freebsd.yml
vendored
@@ -22,14 +22,20 @@ jobs:
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.1"
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y go pkgconf xorg
|
||||
pkg install -y curl pkgconf xorg
|
||||
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
|
||||
225
.github/workflows/golang-test-linux.yml
vendored
225
.github/workflows/golang-test-linux.yml
vendored
@@ -146,6 +146,65 @@ jobs:
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
id: go-env
|
||||
run: |
|
||||
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
${{ steps.go-env.outputs.cache_dir }}
|
||||
${{ steps.go-env.outputs.modcache_dir }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Run tests in container
|
||||
env:
|
||||
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||
run: |
|
||||
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||
|
||||
docker run --rm \
|
||||
--cap-add=NET_ADMIN \
|
||||
--privileged \
|
||||
-v $PWD:/app \
|
||||
-w /app \
|
||||
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
|
||||
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
|
||||
-e CGO_ENABLED=1 \
|
||||
-e CI=true \
|
||||
-e DOCKER_CI=true \
|
||||
-e GOARCH=${GOARCH_TARGET} \
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
golang:1.23-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
name: "Relay / Unit"
|
||||
needs: [build-cache]
|
||||
@@ -179,13 +238,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -232,13 +284,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -286,13 +331,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -314,6 +352,7 @@ jobs:
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/...
|
||||
@@ -353,13 +392,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -380,10 +412,11 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./...
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
@@ -396,6 +429,33 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
run: docker network create promnet
|
||||
|
||||
- name: Start Prometheus Pushgateway
|
||||
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||
|
||||
- name: Start Prometheus (for Pushgateway forwarding)
|
||||
run: |
|
||||
echo '
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_configs:
|
||||
- job_name: "pushgateway"
|
||||
static_configs:
|
||||
- targets: ["pushgateway:9091"]
|
||||
remote_write:
|
||||
- url: ${{ secrets.GRAFANA_URL }}
|
||||
basic_auth:
|
||||
username: ${{ secrets.GRAFANA_USER }}
|
||||
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||
' > prometheus.yml
|
||||
|
||||
docker run -d --name prometheus --network promnet \
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@@ -420,13 +480,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -447,11 +500,13 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
api_integration_test:
|
||||
@@ -489,13 +544,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -505,89 +553,8 @@ jobs:
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
||||
|
||||
- run: chmod +x *testing.bin
|
||||
|
||||
- name: Run Shared Sock tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
||||
@@ -178,6 +178,7 @@ jobs:
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -96,6 +96,20 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-upload
|
||||
dir: upload-server
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-upload
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -409,6 +423,52 @@ dockers:
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
@@ -475,7 +535,17 @@ docker_manifests:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- name_template: netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/upload:latest
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
|
||||
20
README.md
20
README.md
@@ -57,16 +57,16 @@
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
FROM alpine:3.21.3
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
@@ -11,9 +11,12 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
@@ -84,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
}
|
||||
if debugUploadBundle {
|
||||
request.UploadURL = debugUploadBundleURL
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if debugUploadBundle {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -208,12 +222,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
request := &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
})
|
||||
}
|
||||
if debugUploadBundle {
|
||||
request.UploadURL = debugUploadBundleURL
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -239,7 +256,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Println(resp.GetPath())
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if debugUploadBundle {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -326,3 +351,34 @@ func formatDuration(d time.Duration) string {
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
|
||||
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||
var networkMap *mgmProto.NetworkMap
|
||||
var err error
|
||||
|
||||
if connectClient != nil {
|
||||
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get latest network map: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
debug.GeneratorDependencies{
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
NetworkMap: networkMap,
|
||||
LogFile: logFilePath,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||
}
|
||||
|
||||
39
client/cmd/debug_unix.go
Normal file
39
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
usr1Ch := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-usr1Ch:
|
||||
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
126
client/cmd/debug_windows.go
Normal file
126
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||
|
||||
waitTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||
// Example usage with PowerShell:
|
||||
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||
// $evt.Set()
|
||||
// $evt.Close()
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
env := os.Getenv(envListenEvent)
|
||||
if env == "" {
|
||||
return
|
||||
}
|
||||
|
||||
listenEvent, err := strconv.ParseBool(env)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||
return
|
||||
}
|
||||
if !listenEvent {
|
||||
return
|
||||
}
|
||||
|
||||
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: restrict access by ACL
|
||||
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||
if err != nil {
|
||||
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||
} else {
|
||||
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if eventHandle == windows.InvalidHandle {
|
||||
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||
|
||||
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||
}
|
||||
|
||||
func waitForEvent(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
eventHandle windows.Handle,
|
||||
) {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||
|
||||
switch status {
|
||||
case windows.WAIT_OBJECT_0:
|
||||
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||
|
||||
// reset the event so it can be triggered again later (manual reset == 1)
|
||||
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
case uint32(windows.WAIT_TIMEOUT):
|
||||
|
||||
default:
|
||||
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,10 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "login to the Netbird Management Service (first run)",
|
||||
@@ -51,6 +55,9 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
// update host's static platform and system information
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
@@ -127,7 +134,7 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -198,7 +205,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
@@ -212,19 +219,27 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
if noBrowser {
|
||||
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||
} else {
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
cmd.Println("")
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
|
||||
if !noBrowser {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,6 +40,9 @@ const (
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -75,6 +79,8 @@ var (
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
blockLANAccess bool
|
||||
debugUploadBundle bool
|
||||
debugUploadBundleURL string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -181,6 +187,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
|
||||
@@ -16,12 +16,17 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func (p *program) Start(svc service.Service) error {
|
||||
// Start should not block. Do the actual work async.
|
||||
log.Info("starting Netbird service") //nolint
|
||||
|
||||
// Collect static system and platform information
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||
p.serv = grpc.NewServer()
|
||||
|
||||
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
if err != nil {
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -32,7 +34,7 @@ import (
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &mgmt.Config{}
|
||||
config := &types.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -67,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
@@ -90,13 +92,18 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -32,12 +32,16 @@ const (
|
||||
|
||||
const (
|
||||
dnsLabelsFlag = "extra-dns-labels"
|
||||
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
)
|
||||
|
||||
var (
|
||||
foregroundMode bool
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -65,6 +69,9 @@ func init() {
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -212,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
@@ -349,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
|
||||
@@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -243,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -38,10 +38,12 @@ const (
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
matchSet = "--match-set"
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
@@ -55,18 +57,18 @@ type ruleInfo struct {
|
||||
}
|
||||
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Source firewall.Network
|
||||
Destination firewall.Network
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
SetName string
|
||||
}
|
||||
|
||||
type routeRules map[string][]string
|
||||
|
||||
// the ipset library currently does not support comments, so we use the name only (string)
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
type router struct {
|
||||
@@ -115,6 +117,10 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
@@ -123,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
@@ -134,27 +140,28 @@ func (r *router) AddRouteFiltering(
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
var setName string
|
||||
var source firewall.Network
|
||||
if len(sources) > 1 {
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
source.Set = firewall.NewPrefixSet(sources)
|
||||
} else if len(sources) > 0 {
|
||||
source.Prefix = sources[0]
|
||||
}
|
||||
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: sources,
|
||||
Source: source,
|
||||
Destination: destination,
|
||||
Proto: proto,
|
||||
SPort: sPort,
|
||||
DPort: dPort,
|
||||
Action: action,
|
||||
SetName: setName,
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
rule, err := r.genRouteRuleSpec(params, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate route rule spec: %w", err)
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||
@@ -177,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.ID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
@@ -198,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule []string) string {
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
return rule[i+3]
|
||||
func (r *router) decrementSetCounter(rule []string) error {
|
||||
sets := r.findSets(rule)
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) findSets(rule []string) []string {
|
||||
var sets []string
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
sets = append(sets, rule[i+3])
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
@@ -225,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -264,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -307,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -348,12 +370,16 @@ func (r *router) Reset() error {
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.rules = make(map[string][]string)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -423,6 +449,57 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
@@ -464,7 +541,7 @@ func (r *router) insertEstablishedRule(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
// Jump to NAT chain
|
||||
// Jump to nat chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||
@@ -538,12 +615,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
"-s", pair.Source.String(),
|
||||
"-d", pair.Destination.String(),
|
||||
)
|
||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -s: %w", err)
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule,
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
@@ -561,6 +652,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
@@ -726,17 +821,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
|
||||
var rule []string
|
||||
|
||||
if params.SetName != "" {
|
||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||
} else if len(params.Sources) > 0 {
|
||||
source := params.Sources[0]
|
||||
rule = append(rule, "-s", source.String())
|
||||
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -s: %w", err)
|
||||
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", params.Destination, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
@@ -746,7 +845,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
|
||||
rule = append(rule, "-j", actionToStr(params.Action))
|
||||
|
||||
return rule
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
direction := "src"
|
||||
if flag == "-d" {
|
||||
direction = "dst"
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 5. jump rule to PRE nat chain
|
||||
// 6. static outbound masquerade rule
|
||||
// 7. static return masquerade rule
|
||||
require.Len(t, manager.rules, 7, "should have created rules map")
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
@@ -58,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
@@ -345,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
var source firewall.Network
|
||||
if len(tt.sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||
} else if len(tt.sources) > 0 {
|
||||
source.Prefix = tt.sources[0]
|
||||
}
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Sources: tt.sources,
|
||||
Destination: tt.destination,
|
||||
Source: source,
|
||||
Destination: firewall.Network{Prefix: tt.destination},
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
SetName: "",
|
||||
}
|
||||
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
@@ -376,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSetNameInRule(t *testing.T) {
|
||||
r := &router{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Basic rule with two sets",
|
||||
rule: []string{
|
||||
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
|
||||
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
|
||||
},
|
||||
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
|
||||
},
|
||||
{
|
||||
name: "No sets",
|
||||
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Multiple sets with different positions",
|
||||
rule: []string{
|
||||
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
|
||||
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
|
||||
},
|
||||
expected: []string{"set1", "set-abc123"},
|
||||
},
|
||||
{
|
||||
name: "Boundary case - sequence appears at end",
|
||||
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
|
||||
expected: []string{"final-set"},
|
||||
},
|
||||
{
|
||||
name: "Incomplete pattern - missing set name",
|
||||
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := r.findSets(tc.rule)
|
||||
|
||||
if len(result) != len(tc.expected) {
|
||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, set := range result {
|
||||
if set != tc.expected[i] {
|
||||
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +40,18 @@ const (
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
// String returns the string representation of the action
|
||||
func (a Action) String() string {
|
||||
switch a {
|
||||
case ActionAccept:
|
||||
return "accept"
|
||||
case ActionDrop:
|
||||
return "drop"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
@@ -50,6 +59,33 @@ const (
|
||||
ActionDrop
|
||||
)
|
||||
|
||||
// Network is a rule destination, either a set or a prefix
|
||||
type Network struct {
|
||||
Set Set
|
||||
Prefix netip.Prefix
|
||||
}
|
||||
|
||||
// String returns the string representation of the destination
|
||||
func (d Network) String() string {
|
||||
if d.Prefix.IsValid() {
|
||||
return d.Prefix.String()
|
||||
}
|
||||
if d.IsSet() {
|
||||
return d.Set.HashedName()
|
||||
}
|
||||
return "<invalid network>"
|
||||
}
|
||||
|
||||
// IsSet returns true if the destination is a set
|
||||
func (d Network) IsSet() bool {
|
||||
return d.Set != Set{}
|
||||
}
|
||||
|
||||
// IsPrefix returns true if the destination is a valid prefix
|
||||
func (d Network) IsPrefix() bool {
|
||||
return d.Prefix.IsValid()
|
||||
}
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
@@ -83,10 +119,9 @@ type Manager interface {
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
sPort, dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
@@ -119,6 +154,9 @@ type Manager interface {
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
@@ -153,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
|
||||
@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
result1 := manager.NewPrefixSet(prefixes1)
|
||||
result2 := manager.NewPrefixSet(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
result := manager.NewPrefixSet(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result1 := manager.NewPrefixSet([]netip.Prefix{})
|
||||
result2 := manager.NewPrefixSet([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
result1 := manager.NewPrefixSet(prefixes1)
|
||||
result2 := manager.NewPrefixSet(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouterPair struct {
|
||||
ID route.ID
|
||||
Source netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Source Network
|
||||
Destination Network
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
}
|
||||
|
||||
74
client/firewall/manager/set.go
Normal file
74
client/firewall/manager/set.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type Set struct {
|
||||
hash [4]byte
|
||||
comment string
|
||||
}
|
||||
|
||||
// String returns the string representation of the set: hashed name and comment
|
||||
func (h Set) String() string {
|
||||
if h.comment == "" {
|
||||
return h.HashedName()
|
||||
}
|
||||
return h.HashedName() + ": " + h.comment
|
||||
}
|
||||
|
||||
// HashedName returns the string representation of the hash
|
||||
func (h Set) HashedName() string {
|
||||
return fmt.Sprintf(
|
||||
"nb-%s",
|
||||
hex.EncodeToString(h.hash[:]),
|
||||
)
|
||||
}
|
||||
|
||||
// Comment returns the comment of the set
|
||||
func (h Set) Comment() string {
|
||||
return h.comment
|
||||
}
|
||||
|
||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(prefixes)
|
||||
|
||||
hash := sha256.New()
|
||||
for _, src := range prefixes {
|
||||
bytes, err := src.MarshalBinary()
|
||||
if err != nil {
|
||||
log.Warnf("failed to marshal prefix %s: %v", src, err)
|
||||
}
|
||||
hash.Write(bytes)
|
||||
}
|
||||
var set Set
|
||||
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
// NewDomainSet generates a unique name for an ipset based on the given domains.
|
||||
func NewDomainSet(domains domain.List) Set {
|
||||
slices.Sort(domains)
|
||||
|
||||
hash := sha256.New()
|
||||
for _, d := range domains {
|
||||
hash.Write([]byte(d.PunycodeString()))
|
||||
}
|
||||
set := Set{
|
||||
comment: domains.SafeString(),
|
||||
}
|
||||
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||
|
||||
return set
|
||||
}
|
||||
@@ -25,9 +25,10 @@ const (
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNamePrerouting = "netbird-rt-prerouting"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
@@ -462,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
})
|
||||
}
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
|
||||
@@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
@@ -242,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close closes the firewall manager
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -359,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
@@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
@@ -44,9 +43,14 @@ const (
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
)
|
||||
|
||||
type setInput struct {
|
||||
set firewall.Set
|
||||
prefixes []netip.Prefix
|
||||
}
|
||||
|
||||
type router struct {
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
@@ -54,7 +58,7 @@ type router struct {
|
||||
chains map[string]*nftables.Chain
|
||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||
rules map[string]*nftables.Rule
|
||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
@@ -100,6 +104,10 @@ func (r *router) init(workTable *nftables.Table) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -159,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
@@ -196,15 +204,21 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
// Chain is created by acl manager
|
||||
// TODO: move creation to a common place
|
||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePostrouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
}
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
@@ -220,7 +234,83 @@ func (r *router) createContainers() error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||
return errors.New("no mangle chains found")
|
||||
}
|
||||
|
||||
ctNew := getCtNewExprs()
|
||||
preExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
preExprs = append(preExprs, ctNew...)
|
||||
preExprs = append(preExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
preNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: preExprs,
|
||||
}
|
||||
r.conn.AddRule(preNftRule)
|
||||
|
||||
postExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
postExprs = append(postExprs, ctNew...)
|
||||
postExprs = append(postExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
postNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePostrouting],
|
||||
Exprs: postExprs,
|
||||
}
|
||||
r.conn.AddRule(postNftRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -230,7 +320,7 @@ func (r *router) createContainers() error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
@@ -245,23 +335,29 @@ func (r *router) AddRouteFiltering(
|
||||
chain := r.chains[chainNameRoutingFw]
|
||||
var exprs []expr.Any
|
||||
|
||||
var source firewall.Network
|
||||
switch {
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||
case len(sources) == 1:
|
||||
// If there's only one source, we can use it directly
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
||||
source.Prefix = sources[0]
|
||||
default:
|
||||
// If there are multiple sources, create or get an ipset
|
||||
var err error
|
||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
||||
}
|
||||
// If there are multiple sources, use a set
|
||||
source.Set = firewall.NewPrefixSet(sources)
|
||||
}
|
||||
|
||||
// Handle destination
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
sourceExp, err := r.applyNetwork(source, sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
exprs = append(exprs, sourceExp...)
|
||||
|
||||
destExp, err := r.applyNetwork(destination, nil, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
@@ -305,39 +401,27 @@ func (r *router) AddRouteFiltering(
|
||||
rule = r.conn.AddRule(rule)
|
||||
}
|
||||
|
||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
r.rules[string(ruleKey)] = rule
|
||||
|
||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||
set: set,
|
||||
prefixes: prefixes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
)
|
||||
return exprs, nil
|
||||
return getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
@@ -356,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||
}
|
||||
|
||||
setName := r.findSetNameInRule(nftRule)
|
||||
|
||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||
return fmt.Errorf("delete: %w", err)
|
||||
}
|
||||
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
Table: r.workTable,
|
||||
nfset := &nftables.Set{
|
||||
Name: setName,
|
||||
Comment: input.set.Comment(),
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range sources {
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -407,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
|
||||
if err := r.conn.AddSet(set, elements); err != nil {
|
||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||
|
||||
return set, nil
|
||||
return elements
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
@@ -442,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
r.conn.DelSet(set)
|
||||
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||
r.conn.DelSet(nfset)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
@@ -452,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
return lookup.SetName
|
||||
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
|
||||
sets := r.findSets(rule)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) findSets(rule *nftables.Rule) []string {
|
||||
var sets []string
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
sets = append(sets, lookup.SetName)
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
@@ -500,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -508,8 +608,15 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
@@ -517,26 +624,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
|
||||
// interface matching
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
@@ -547,6 +634,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs = append(exprs, getCtNewExprs()...)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
@@ -576,9 +666,11 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNamePrerouting],
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
@@ -659,8 +751,15 @@ func (r *router) addPostroutingRules() error {
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
@@ -669,7 +768,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
},
|
||||
}
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
@@ -682,7 +782,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: expression,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
@@ -697,11 +797,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
} else {
|
||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -912,12 +1014,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -925,10 +1029,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
||||
// TODO: rollback set counter
|
||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -936,16 +1040,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
if err != nil {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -957,7 +1064,7 @@ func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
||||
return fmt.Errorf(" unable to list rules: %v", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
@@ -1231,13 +1338,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||
var offset uint32
|
||||
if source {
|
||||
offset = 12 // src offset
|
||||
} else {
|
||||
offset = 16 // dst offset
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
elements := convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
setPrefixes []netip.Prefix,
|
||||
isSource bool,
|
||||
) ([]expr.Any, error) {
|
||||
if network.IsSet() {
|
||||
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source: %w", err)
|
||||
}
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return applyPrefix(network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
@@ -1324,3 +1472,48 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func getCtNewExprs() []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
|
||||
// dst offset
|
||||
offset := uint32(16)
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = 12
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
if chain.Name == chainNameManglePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
printNftSets()
|
||||
|
||||
@@ -15,8 +15,8 @@ var (
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
@@ -24,8 +24,8 @@ var (
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +40,8 @@ var (
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -22,7 +21,7 @@ const (
|
||||
firewallRuleName = "Netbird"
|
||||
)
|
||||
|
||||
// Reset firewall to the default state
|
||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||
func (m *Manager) Close(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
|
||||
@@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPFin uint8 = 0x01
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPRst uint8 = 0x04
|
||||
TCPPush uint8 = 0x08
|
||||
TCPAck uint8 = 0x10
|
||||
TCPUrg uint8 = 0x20
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int
|
||||
type TCPState int32
|
||||
|
||||
func (s TCPState) String() string {
|
||||
switch s {
|
||||
@@ -89,22 +89,25 @@ const (
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
tombstone atomic.Bool
|
||||
sync.RWMutex
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
state atomic.Int32
|
||||
tombstone atomic.Bool
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
// GetState safely retrieves the current state
|
||||
func (t *TCPConnTrack) GetState() TCPState {
|
||||
return TCPState(t.state.Load())
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
// SetState safely updates the current state
|
||||
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||
t.state.Store(int32(state))
|
||||
}
|
||||
|
||||
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||
}
|
||||
|
||||
// IsTombstone safely checks if the connection is marked for deletion
|
||||
@@ -125,13 +128,17 @@ type TCPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||
waitTimeout := TimeWaitTimeout
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTCPTimeout
|
||||
} else {
|
||||
waitTimeout = timeout / 45
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||
conn.Unlock()
|
||||
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
@@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
@@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists {
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.established.Store(false)
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
|
||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||
conn.UpdateCounters(direction, size)
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
@@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
@@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if !exists || conn.IsTombstone() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
return t.handleRst(key, conn, size)
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, false)
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return isEstablished || isValidState
|
||||
}
|
||||
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool {
|
||||
if conn.IsTombstone() {
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
conn.SetTombstone()
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
state := conn.State
|
||||
defer func() {
|
||||
if state != conn.State {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
switch state {
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
conn.State = TCPStateSynSent
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateSynReceived
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateFinWait1
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
conn.State = TCPStateCloseWait
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
} else if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPRst != 0:
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
|
||||
t.logger.Trace("TCP connection %s completed", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateLastAck
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Send close event for gracefully closed connections
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||
case TCPStateSynSent:
|
||||
// TODO: support simultaneous open
|
||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||
case TCPStateSynReceived:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPRst != 0 {
|
||||
return true
|
||||
}
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateFinWait1:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
@@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
@@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
timeout = TimeWaitTimeout
|
||||
case conn.IsEstablished():
|
||||
currentState := conn.GetState()
|
||||
switch currentState {
|
||||
case TCPStateTimeWait:
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
default:
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
// Return IPs to pool
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if conn.State != TCPStateTimeWait {
|
||||
if currentState != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) {
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateClosed, conn.State)
|
||||
require.False(t, conn.IsEstablished())
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPRetransmissions(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test SYN retransmission
|
||||
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||
// Initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Retransmit SYN (should not affect the state machine)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Verify we're still in SYN-SENT state
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// Complete the handshake
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Verify we're in ESTABLISHED state
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test ACK retransmission in established state
|
||||
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Retransmit ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test FIN retransmission
|
||||
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Retransmit FIN (should not change state)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPDataTransfer(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Data Transfer", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Receive ACK for data
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||
require.True(t, valid)
|
||||
|
||||
// Receive data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send ACK for received data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test half-closed connection: local end closes, remote end continues sending data
|
||||
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Remote end can still send data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||
require.True(t, valid)
|
||||
|
||||
// We can still ACK their data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Receive FIN from remote end
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
})
|
||||
|
||||
// Test half-closed connection: remote end closes, local end continues sending data
|
||||
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Receive FIN from remote
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// We can still send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Remote can still ACK our data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send our FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Receive final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPAbnormalSequences(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test handling of unsolicited RST in various states
|
||||
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||
// Send SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive unsolicited RST (without proper ACK)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||
|
||||
// Receive RST with proper ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
require.True(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Connection Timeout", func(t *testing.T) {
|
||||
// Establish a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Wait for the connection to timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
// Force cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after timeout")
|
||||
})
|
||||
|
||||
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Complete the connection close to enter TIME_WAIT
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||
// For the test, we're using a short timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSynFlood(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
basePort := uint16(10000)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Create a large number of SYN packets to simulate a SYN flood
|
||||
for i := uint16(0); i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Check that we're tracking all connections
|
||||
require.Equal(t, 1000, len(tracker.connections))
|
||||
|
||||
// Now simulate SYN timeout
|
||||
var oldConns int
|
||||
tracker.mutex.Lock()
|
||||
for _, conn := range tracker.connections {
|
||||
if conn.GetState() == TCPStateSynSent {
|
||||
// Make the connection appear old
|
||||
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||
oldConns++
|
||||
}
|
||||
}
|
||||
tracker.mutex.Unlock()
|
||||
require.Equal(t, 1000, oldConns)
|
||||
|
||||
// Run cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Check that stale connections were cleaned up
|
||||
require.Equal(t, 0, len(tracker.connections))
|
||||
}
|
||||
|
||||
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||
clientPort := uint16(12345)
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
DstIP: serverIP,
|
||||
SrcPort: clientPort,
|
||||
DstPort: serverPort,
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
conn := tracker.connections[key]
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||
|
||||
// 2. Server sends SYN-ACK response
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
// Server sends data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -17,6 +19,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
@@ -29,8 +32,10 @@ const (
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
@@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
||||
|
||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||
key := buildKey(srcIP, dstIP, srcPort, dstPort)
|
||||
f.ruleIdMap.LoadOrStore(key, ruleID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
|
||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return value.([]byte), true
|
||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||
return value.([]byte), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return
|
||||
}
|
||||
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
|
||||
}
|
||||
|
||||
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
|
||||
return conntrack.ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return
|
||||
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return
|
||||
return 0
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
|
||||
var rxPackets, txPackets uint64
|
||||
if rxBytes > 0 {
|
||||
rxPackets = 1
|
||||
}
|
||||
if txBytes > 0 {
|
||||
txPackets = 1
|
||||
}
|
||||
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
// TODO: get packets/bytes
|
||||
})
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||
}()
|
||||
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||
return
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
|
||||
@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
@@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||
}()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var txBytes, rxBytes int64
|
||||
var outboundErr, inboundErr error
|
||||
|
||||
// outbound->inbound: copy from pConn.conn to pConn.outConn
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
defer wg.Done()
|
||||
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
// inbound->outbound: copy from pConn.outConn to pConn.conn
|
||||
go func() {
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
defer wg.Done()
|
||||
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||
return
|
||||
wg.Wait()
|
||||
|
||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
|
||||
}
|
||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
rxPackets = udpStats.PacketsSent.Value()
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if ep != nil {
|
||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
// TODO: get bytes
|
||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
// copy reads from src and writes to dst.
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
var totalBytes int64 = 0
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
return totalBytes, ctx.Err()
|
||||
}
|
||||
|
||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
return totalBytes, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
nWritten, err := dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
totalBytes += int64(nWritten)
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,8 +14,13 @@ import (
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
@@ -27,35 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
ipStr := ip.String()
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
ipStr := ipv4.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -73,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -86,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
@@ -120,12 +149,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if ip.Is4() {
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
return false
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
}
|
||||
|
||||
@@ -77,6 +77,18 @@ func TestLocalIPManager(t *testing.T) {
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
@@ -192,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
@@ -248,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -257,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm := newLocalIPManager()
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
|
||||
@@ -29,14 +29,15 @@ func (r *PeerRule) ID() string {
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
|
||||
@@ -198,12 +198,12 @@ func TestTracePacket(t *testing.T) {
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -222,12 +222,12 @@ func TestTracePacket(t *testing.T) {
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -245,7 +245,7 @@ func TestTracePacket(t *testing.T) {
|
||||
m.nativeRouter.Store(true)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -263,7 +263,7 @@ func TestTracePacket(t *testing.T) {
|
||||
m.routingEnabled.Store(false)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
@@ -425,8 +425,8 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||
"100.10.0.100 should be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||
"172.17.0.2 should not be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
|
||||
"192.168.17.2 should not be recognized as a local IP")
|
||||
|
||||
pb := tc.packetBuilder()
|
||||
|
||||
|
||||
@@ -49,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
type RouteRules []RouteRule
|
||||
type RouteRules []*RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
||||
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
return -1
|
||||
@@ -99,6 +99,8 @@ type Manager struct {
|
||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
|
||||
blockRule firewall.Rule
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -201,41 +203,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.blockInvalidRouted(iface); err != nil {
|
||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse wireguard network: %w", err)
|
||||
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||
}
|
||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||
|
||||
if _, err := m.AddRouteFiltering(
|
||||
rule, err := m.addRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
wgPrefix,
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
firewall.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionDrop,
|
||||
); err != nil {
|
||||
return fmt.Errorf("block wg nte : %w", err)
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||
}
|
||||
|
||||
// TODO: Block networks that we're a client of
|
||||
|
||||
return nil
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) determineRouting() error {
|
||||
@@ -413,10 +409,23 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func (m *Manager) addRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
@@ -426,34 +435,39 @@ func (m *Manager) AddRouteFiltering(
|
||||
ruleID := uuid.New().String()
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
destination: destination,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
proto: proto,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
if destination.IsPrefix() {
|
||||
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
m.routeRules = append(m.routeRules, rule)
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
ruleID := rule.ID()
|
||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == ruleID
|
||||
})
|
||||
if idx < 0 {
|
||||
@@ -509,6 +523,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var matches []*RouteRule
|
||||
for _, rule := range m.routeRules {
|
||||
if rule.dstSet == set {
|
||||
matches = append(matches, rule)
|
||||
}
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return fmt.Errorf("no route rule found for set: %s", set)
|
||||
}
|
||||
|
||||
destinations := matches[0].destinations
|
||||
for _, prefix := range prefixes {
|
||||
if prefix.Addr().Is4() {
|
||||
destinations = append(destinations, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||
cmp := a.Addr().Compare(b.Addr())
|
||||
if cmp != 0 {
|
||||
return cmp
|
||||
}
|
||||
return a.Bits() - b.Bits()
|
||||
})
|
||||
|
||||
destinations = slices.Compact(destinations)
|
||||
|
||||
for _, rule := range matches {
|
||||
rule.destinations = destinations
|
||||
}
|
||||
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||
return m.processOutgoingHooks(packetData, size)
|
||||
@@ -658,7 +718,8 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if !m.isValidPacket(d, packetData) {
|
||||
valid, fragment := m.isValidPacket(d, packetData)
|
||||
if !valid {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -668,6 +729,13 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
return false
|
||||
}
|
||||
|
||||
// For all inbound traffic, first check if it matches a tracked connection.
|
||||
// This must happen before any other filtering because the packets are statefully tracked.
|
||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||
@@ -678,7 +746,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
|
||||
}
|
||||
|
||||
// handleLocalTraffic handles local traffic.
|
||||
@@ -739,7 +807,7 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||
|
||||
// handleRoutedTraffic handles routed traffic.
|
||||
// If it returns true, the packet should be dropped.
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
@@ -749,13 +817,15 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
|
||||
// Pass to native stack if native router is enabled or forced
|
||||
if m.nativeRouter.Load() {
|
||||
m.trackInbound(d, srcIP, dstIP, nil, size)
|
||||
return false
|
||||
}
|
||||
|
||||
proto, pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
if !pass {
|
||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
@@ -770,6 +840,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
// TODO: icmp type/code
|
||||
RxPackets: 1,
|
||||
RxBytes: uint64(size),
|
||||
})
|
||||
return true
|
||||
}
|
||||
@@ -779,8 +851,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
if fwd == nil {
|
||||
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
||||
} else {
|
||||
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
|
||||
|
||||
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||
m.logger.Error("Failed to inject routed packet: %v", err)
|
||||
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -812,17 +887,32 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
||||
// isValidPacket checks if the packet is valid.
|
||||
// It returns true, false if the packet is valid and not a fragment.
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||
return false
|
||||
return false, false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false
|
||||
l := len(d.decoded)
|
||||
|
||||
// L3 and L4 are mandatory
|
||||
if l >= 2 {
|
||||
return true, false
|
||||
}
|
||||
return true
|
||||
|
||||
// Fragments are also valid
|
||||
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||
ip4 := d.ip4
|
||||
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||
return true, true
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Trace("packet doesn't have network and transport layers")
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||
@@ -962,8 +1052,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
if !rule.destination.Contains(dstAddr) {
|
||||
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||
destMatched := false
|
||||
for _, dst := range rule.destinations {
|
||||
if dst.Contains(dstAddr) {
|
||||
destMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !destMatched {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1065,7 +1162,22 @@ func (m *Manager) EnableRouting() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.determineRouting()
|
||||
if err := m.determineRouting(); err != nil {
|
||||
return fmt.Errorf("determine routing: %w", err)
|
||||
}
|
||||
|
||||
if m.forwarder.Load() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rule, err := m.blockInvalidRouted(m.wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("block invalid routed: %w", err)
|
||||
}
|
||||
|
||||
m.blockRule = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
@@ -1090,5 +1202,12 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
log.Debug("forwarder stopped")
|
||||
|
||||
if m.blockRule != nil {
|
||||
if err := m.deleteRouteRule(m.blockRule); err != nil {
|
||||
return fmt.Errorf("delete block rule: %w", err)
|
||||
}
|
||||
m.blockRule = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
@@ -188,6 +189,281 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow UDP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "TCP packet doesn't match UDP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "UDP packet doesn't match TCP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match TCP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match UDP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic within port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Block TCP traffic outside port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Edge Case - Port at Range Boundary",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8100,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "UDP Port Range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 5060,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple source ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
// New drop test cases
|
||||
{
|
||||
name: "Drop TCP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop UDP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop ICMP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop all traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolALL,
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop traffic from multiple source ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop TCP traffic within port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Accept TCP traffic outside drop port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Drop TCP traffic with source port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 32100,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Mixed rule - drop specific port but allow other ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
@@ -198,6 +474,28 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
@@ -303,8 +601,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
@@ -321,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
type rule struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -347,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -363,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -379,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -395,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 53,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{Values: []uint16{53}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -409,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
proto: fw.ProtocolICMP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
@@ -424,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -440,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -456,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -472,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -488,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -507,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -521,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
@@ -536,33 +834,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple source networks with mismatched protocol",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
// Should not match TCP rule
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
@@ -572,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -588,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -604,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
@@ -621,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -640,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 7999,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -659,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -678,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -700,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8100,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -719,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 5060,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -738,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -757,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -773,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
@@ -791,17 +1069,158 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "Drop empty destination set",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: fw.Network{Set: fw.Set{}},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "Accept TCP traffic outside drop port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Allow UDP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "TCP packet doesn't match UDP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "UDP packet doesn't match TCP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match TCP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match UDP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.rule.action == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
@@ -836,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name string
|
||||
rules []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -857,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name: "Drop rules take precedence over accept",
|
||||
rules: []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -866,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Accept rule added first
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -874,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Drop rule added second but should be evaluated first
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -912,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name: "Multiple drop rules take precedence",
|
||||
rules: []struct {
|
||||
sources []netip.Prefix
|
||||
dest netip.Prefix
|
||||
dest fw.Network
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -921,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Accept all
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
{
|
||||
// Drop specific port
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -936,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Drop different port
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -1015,3 +1434,53 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteACLSet(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
// Add rule that uses the set (initially empty)
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -21,10 +20,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
@@ -712,3 +712,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSetMerge(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
initialPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
// Update the set with initial prefixes
|
||||
err = manager.UpdateSet(set, initialPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test initial prefixes work
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP1 := netip.MustParseAddr("10.0.0.100")
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
|
||||
|
||||
newPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, newPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
// Check that new prefixes are included
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
|
||||
// Verify the rule has all prefixes
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
foundRule = true
|
||||
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||
"Rule should have all prefixes merged")
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
require.True(t, foundRule, "Rule should be found")
|
||||
}
|
||||
|
||||
func TestUpdateSetDeduplication(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
initialPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, initialPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the internal state for deduplication
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
foundRule = true
|
||||
// Should have deduplicated to 2 prefixes
|
||||
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||
|
||||
// Check the prefixes are correct
|
||||
expectedPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
for i, prefix := range expectedPrefixes {
|
||||
require.True(t, r.destinations[i] == prefix,
|
||||
"Prefix should match expected value")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
require.True(t, foundRule, "Rule should be found")
|
||||
|
||||
// Test with overlapping prefixes of different sizes
|
||||
overlappingPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/16"), // More general
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
|
||||
netip.MustParsePrefix("192.168.0.0/16"), // More general
|
||||
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, overlappingPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||
manager.mutex.RLock()
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||
require.Len(t, r.destinations, 4,
|
||||
"Overlapping prefixes should not be deduplicated")
|
||||
|
||||
// Verify they're sorted correctly (more specific prefixes should come first)
|
||||
prefixes := make([]string, 0, len(r.destinations))
|
||||
for _, p := range r.destinations {
|
||||
prefixes = append(prefixes, p.String())
|
||||
}
|
||||
|
||||
// Check sorted order
|
||||
require.Equal(t, []string{
|
||||
"10.0.0.0/16",
|
||||
"10.0.0.0/24",
|
||||
"192.168.0.0/16",
|
||||
"192.168.1.0/24",
|
||||
}, prefixes, "Prefixes should be sorted")
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
// Test functionality with all prefixes
|
||||
testCases := []struct {
|
||||
dstIP netip.Addr
|
||||
expected bool
|
||||
desc string
|
||||
}{
|
||||
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
|
||||
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
|
||||
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
|
||||
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
|
||||
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
|
||||
}
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
|
||||
mux := &UDPMuxDefault{
|
||||
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
|
||||
buf: make([]byte, size),
|
||||
}
|
||||
}
|
||||
|
||||
func getLogger() logging.LeveledLogger {
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
//fac.Writer = log.StandardLogger().Writer()
|
||||
return fac.NewLogger("ice")
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
|
||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
params.Logger = getLogger()
|
||||
}
|
||||
if params.XORMappedAddrCacheTTL == 0 {
|
||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||
|
||||
@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||
|
||||
func getFwmark() int {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return nbnet.NetbirdFwmark
|
||||
return nbnet.ControlPlaneMark
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
|
||||
|
||||
func GenerateRouteRuleKey(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
destination manager.Network,
|
||||
proto manager.Protocol,
|
||||
sPort *manager.Port,
|
||||
dPort *manager.Port,
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||
}
|
||||
|
||||
type protoMatch struct {
|
||||
@@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
//
|
||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
|
||||
@@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
log.Errorf("failed to set legacy management flag: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
}
|
||||
|
||||
@@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
d.peerRulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||
var merr *multierror.Error
|
||||
|
||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule)
|
||||
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
||||
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||
} else {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||
}
|
||||
@@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", ErrSourceRangesEmpty
|
||||
}
|
||||
@@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
||||
sources = append(sources, source)
|
||||
}
|
||||
|
||||
var destination netip.Prefix
|
||||
if rule.IsDynamic {
|
||||
destination = getDefault(sources[0])
|
||||
} else {
|
||||
var err error
|
||||
destination, err = netip.ParsePrefix(rule.Destination)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse destination: %w", err)
|
||||
}
|
||||
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("determine destination: %w", err)
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||
@@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||
return nil
|
||||
}
|
||||
|
||||
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
|
||||
var destination firewall.Network
|
||||
|
||||
if rule.IsDynamic {
|
||||
if dynamicResolver {
|
||||
if len(rule.Domains) > 0 {
|
||||
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
|
||||
} else {
|
||||
// isDynamic is set but no domains = outdated management server
|
||||
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
|
||||
destination.Prefix = getDefault(sources[0])
|
||||
}
|
||||
} else {
|
||||
// client resolves DNS, we (router) don't know the destination
|
||||
destination.Prefix = getDefault(sources[0])
|
||||
}
|
||||
return destination, nil
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(rule.Destination)
|
||||
if err != nil {
|
||||
return destination, fmt.Errorf("parse destination: %w", err)
|
||||
}
|
||||
destination.Prefix = prefix
|
||||
return destination, nil
|
||||
}
|
||||
|
||||
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||
if prefix.Addr().Is6() {
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -15,7 +14,7 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
@@ -67,7 +66,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
t.Run("apply firewall rules", func(t *testing.T) {
|
||||
acl.ApplyFiltering(networkMap)
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||
@@ -93,7 +92,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
},
|
||||
)
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
// we should have one old and one new rule in the existed rules
|
||||
if len(acl.peerRulesPairs) != 2 {
|
||||
@@ -117,13 +116,13 @@ func TestDefaultManager(t *testing.T) {
|
||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = true
|
||||
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
||||
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
|
||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
}
|
||||
|
||||
networkMap.FirewallRulesIsEmpty = false
|
||||
acl.ApplyFiltering(networkMap)
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
if len(acl.peerRulesPairs) != 1 {
|
||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||
return
|
||||
@@ -360,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap)
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
if len(acl.peerRulesPairs) != 3 {
|
||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||
|
||||
@@ -94,12 +94,17 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
p.codeVerifier = codeVerifier
|
||||
|
||||
codeChallenge := createCodeChallenge(codeVerifier)
|
||||
authURL := p.oAuthConfig.AuthCodeURL(
|
||||
state,
|
||||
|
||||
params := []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
)
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
return AuthFlowInfo{
|
||||
VerificationURIComplete: authURL,
|
||||
|
||||
49
client/internal/auth/pkce_flow_test.go
Normal file
49
client/internal/auth/pkce_flow_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestPromptLogin(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
prompt bool
|
||||
}{
|
||||
{"PromptLogin", true},
|
||||
{"NoPromptLogin", false},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
Scope: "openid email profile",
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
DisablePromptLogin: !tc.prompt,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
|
||||
}
|
||||
authInfo, err := pkce.RequestAuthInfo(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
pattern := "prompt=login"
|
||||
if tc.prompt {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||
} else {
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
|
||||
return e
|
||||
}
|
||||
|
||||
// GetLatestNetworkMap returns the latest network map from the engine.
|
||||
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||
engine := c.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine is not initialized")
|
||||
}
|
||||
|
||||
networkMap, err := engine.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||
}
|
||||
|
||||
if networkMap == nil {
|
||||
return nil, errors.New("network map is not available")
|
||||
}
|
||||
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
// Status returns the current client status
|
||||
func (c *ConnectClient) Status() StatusType {
|
||||
if c == nil {
|
||||
|
||||
1022
client/internal/debug/debug.go
Normal file
1022
client/internal/debug/debug.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,8 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
package debug
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
@@ -14,36 +13,31 @@ import (
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// addFirewallRules collects and adds firewall rules to the archive
|
||||
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
log.Info("Collecting firewall rules")
|
||||
// Collect and add iptables rules
|
||||
iptablesRules, err := collectIPTablesRules()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||
} else {
|
||||
if req.GetAnonymize() {
|
||||
iptablesRules = anonymizer.AnonymizeString(iptablesRules)
|
||||
if g.anonymize {
|
||||
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||
}
|
||||
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Collect and add nftables rules
|
||||
nftablesRules, err := collectNFTablesRules()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect nftables rules: %v", err)
|
||||
} else {
|
||||
if req.GetAnonymize() {
|
||||
nftablesRules = anonymizer.AnonymizeString(nftablesRules)
|
||||
if g.anonymize {
|
||||
nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
|
||||
}
|
||||
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||
if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -65,16 +59,23 @@ func collectIPTablesRules() (string, error) {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Then get verbose statistics for each table
|
||||
// Collect ipset information
|
||||
ipsetOutput, err := collectIPSets()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to collect ipset information: %v", err)
|
||||
} else {
|
||||
builder.WriteString("=== ipset list output ===\n")
|
||||
builder.WriteString(ipsetOutput)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||
|
||||
// Get list of tables
|
||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||
|
||||
for _, table := range tables {
|
||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||
|
||||
// Get verbose statistics for the entire table
|
||||
stats, err := getTableStatistics(table)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||
@@ -87,6 +88,28 @@ func collectIPTablesRules() (string, error) {
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
// collectIPSets collects information about ipsets
|
||||
func collectIPSets() (string, error) {
|
||||
cmd := exec.Command("ipset", "list")
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if strings.Contains(err.Error(), "executable file not found") {
|
||||
return "", fmt.Errorf("ipset command not found: %w", err)
|
||||
}
|
||||
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
|
||||
}
|
||||
|
||||
ipsets := stdout.String()
|
||||
if strings.TrimSpace(ipsets) == "" {
|
||||
return "No ipsets found", nil
|
||||
}
|
||||
|
||||
return ipsets, nil
|
||||
}
|
||||
|
||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||
func collectIPTablesSave() (string, error) {
|
||||
cmd := exec.Command("iptables-save")
|
||||
@@ -182,12 +205,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// Format chains
|
||||
for _, chain := range chains {
|
||||
formatChain(conn, table, chain, &builder)
|
||||
}
|
||||
|
||||
// Format sets
|
||||
if sets, err := conn.GetSets(table); err != nil {
|
||||
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
||||
} else if len(sets) > 0 {
|
||||
7
client/internal/debug/debug_mobile.go
Normal file
7
client/internal/debug/debug_mobile.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build ios || android
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
8
client/internal/debug/debug_nonlinux.go
Normal file
8
client/internal/debug/debug_nonlinux.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package debug
|
||||
|
||||
// collectFirewallRules returns nothing on non-linux systems
|
||||
func (g *BundleGenerator) addFirewallRules() error {
|
||||
return nil
|
||||
}
|
||||
25
client/internal/debug/debug_nonmobile.go
Normal file
25
client/internal/debug/debug_nonmobile.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
routes, err := systemops.GetRoutesFromTable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get routes: %w", err)
|
||||
}
|
||||
|
||||
// TODO: get routes including nexthop
|
||||
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
|
||||
routesReader := strings.NewReader(routesContent)
|
||||
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
|
||||
return fmt.Errorf("add routes file to zip: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
543
client/internal/debug/debug_test.go
Normal file
543
client/internal/debug/debug_test.go
Normal file
@@ -0,0 +1,543 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
func TestAnonymizeStateFile(t *testing.T) {
|
||||
testState := map[string]json.RawMessage{
|
||||
"null_state": json.RawMessage("null"),
|
||||
"test_state": mustMarshal(map[string]any{
|
||||
// Test simple fields
|
||||
"public_ip": "203.0.113.1",
|
||||
"private_ip": "192.168.1.1",
|
||||
"protected_ip": "100.64.0.1",
|
||||
"well_known_ip": "8.8.8.8",
|
||||
"ipv6_addr": "2001:db8::1",
|
||||
"private_ipv6": "fd00::1",
|
||||
"domain": "test.example.com",
|
||||
"uri": "stun:stun.example.com:3478",
|
||||
"uri_with_ip": "turn:203.0.113.1:3478",
|
||||
"netbird_domain": "device.netbird.cloud",
|
||||
|
||||
// Test CIDR ranges
|
||||
"public_cidr": "203.0.113.0/24",
|
||||
"private_cidr": "192.168.0.0/16",
|
||||
"protected_cidr": "100.64.0.0/10",
|
||||
"ipv6_cidr": "2001:db8::/32",
|
||||
"private_ipv6_cidr": "fd00::/8",
|
||||
|
||||
// Test nested structures
|
||||
"nested": map[string]any{
|
||||
"ip": "203.0.113.2",
|
||||
"domain": "nested.example.com",
|
||||
"more_nest": map[string]any{
|
||||
"ip": "203.0.113.3",
|
||||
"domain": "deep.example.com",
|
||||
},
|
||||
},
|
||||
|
||||
// Test arrays
|
||||
"string_array": []any{
|
||||
"203.0.113.4",
|
||||
"test1.example.com",
|
||||
"test2.example.com",
|
||||
},
|
||||
"object_array": []any{
|
||||
map[string]any{
|
||||
"ip": "203.0.113.5",
|
||||
"domain": "array1.example.com",
|
||||
},
|
||||
map[string]any{
|
||||
"ip": "203.0.113.6",
|
||||
"domain": "array2.example.com",
|
||||
},
|
||||
},
|
||||
|
||||
// Test multiple occurrences of same value
|
||||
"duplicate_ip": "203.0.113.1", // Same as public_ip
|
||||
"duplicate_domain": "test.example.com", // Same as domain
|
||||
|
||||
// Test URIs with various schemes
|
||||
"stun_uri": "stun:stun.example.com:3478",
|
||||
"turns_uri": "turns:turns.example.com:5349",
|
||||
"http_uri": "http://web.example.com:80",
|
||||
"https_uri": "https://secure.example.com:443",
|
||||
|
||||
// Test strings that might look like IPs but aren't
|
||||
"not_ip": "300.300.300.300",
|
||||
"partial_ip": "192.168",
|
||||
"ip_like_string": "1234.5678",
|
||||
|
||||
// Test mixed content strings
|
||||
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
|
||||
|
||||
// Test empty and special values
|
||||
"empty_string": "",
|
||||
"null_value": nil,
|
||||
"numeric_value": 42,
|
||||
"boolean_value": true,
|
||||
}),
|
||||
"route_state": mustMarshal(map[string]any{
|
||||
"routes": []any{
|
||||
map[string]any{
|
||||
"network": "203.0.113.0/24",
|
||||
"gateway": "203.0.113.1",
|
||||
"domains": []any{
|
||||
"route1.example.com",
|
||||
"route2.example.com",
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"network": "2001:db8::/32",
|
||||
"gateway": "2001:db8::1",
|
||||
"domains": []any{
|
||||
"route3.example.com",
|
||||
"route4.example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
// Test map with IP/CIDR keys
|
||||
"refCountMap": map[string]any{
|
||||
"203.0.113.1/32": map[string]any{
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "192.168.0.1",
|
||||
"Intf": map[string]any{
|
||||
"Name": "eth0",
|
||||
"Index": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"2001:db8::1/128": map[string]any{
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "fe80::1",
|
||||
"Intf": map[string]any{
|
||||
"Name": "eth0",
|
||||
"Index": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
|
||||
"Count": 1,
|
||||
"Out": map[string]any{
|
||||
"IP": "192.168.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
// Pre-seed the domains we need to verify in the test assertions
|
||||
anonymizer.AnonymizeDomain("test.example.com")
|
||||
anonymizer.AnonymizeDomain("nested.example.com")
|
||||
anonymizer.AnonymizeDomain("deep.example.com")
|
||||
anonymizer.AnonymizeDomain("array1.example.com")
|
||||
|
||||
err := anonymizeStateFile(&testState, anonymizer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Helper function to unmarshal and get nested values
|
||||
var state map[string]any
|
||||
err = json.Unmarshal(testState["test_state"], &state)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test null state remains unchanged
|
||||
require.Equal(t, "null", string(testState["null_state"]))
|
||||
|
||||
// Basic assertions
|
||||
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
|
||||
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
|
||||
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
|
||||
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
|
||||
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
|
||||
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
|
||||
assert.NotEqual(t, "test.example.com", state["domain"])
|
||||
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
|
||||
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
|
||||
|
||||
// CIDR ranges
|
||||
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
|
||||
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
|
||||
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
|
||||
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
|
||||
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
|
||||
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
|
||||
|
||||
// Nested structures
|
||||
nested := state["nested"].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.2", nested["ip"])
|
||||
assert.NotEqual(t, "nested.example.com", nested["domain"])
|
||||
moreNest := nested["more_nest"].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
|
||||
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
|
||||
|
||||
// Arrays
|
||||
strArray := state["string_array"].([]any)
|
||||
assert.NotEqual(t, "203.0.113.4", strArray[0])
|
||||
assert.NotEqual(t, "test1.example.com", strArray[1])
|
||||
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
|
||||
|
||||
objArray := state["object_array"].([]any)
|
||||
firstObj := objArray[0].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
|
||||
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
|
||||
|
||||
// Duplicate values should be anonymized consistently
|
||||
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
|
||||
assert.Equal(t, state["domain"], state["duplicate_domain"])
|
||||
|
||||
// URIs
|
||||
assert.NotContains(t, state["stun_uri"], "stun.example.com")
|
||||
assert.NotContains(t, state["turns_uri"], "turns.example.com")
|
||||
assert.NotContains(t, state["http_uri"], "web.example.com")
|
||||
assert.NotContains(t, state["https_uri"], "secure.example.com")
|
||||
|
||||
// Non-IP strings should remain unchanged
|
||||
assert.Equal(t, "300.300.300.300", state["not_ip"])
|
||||
assert.Equal(t, "192.168", state["partial_ip"])
|
||||
assert.Equal(t, "1234.5678", state["ip_like_string"])
|
||||
|
||||
// Mixed content should have IPs and domains replaced
|
||||
mixedContent := state["mixed_content"].(string)
|
||||
assert.NotContains(t, mixedContent, "203.0.113.1")
|
||||
assert.NotContains(t, mixedContent, "test.example.com")
|
||||
assert.Contains(t, mixedContent, "Server at ")
|
||||
assert.Contains(t, mixedContent, " on port 80")
|
||||
|
||||
// Special values should remain unchanged
|
||||
assert.Equal(t, "", state["empty_string"])
|
||||
assert.Nil(t, state["null_value"])
|
||||
assert.Equal(t, float64(42), state["numeric_value"])
|
||||
assert.Equal(t, true, state["boolean_value"])
|
||||
|
||||
// Check route state
|
||||
var routeState map[string]any
|
||||
err = json.Unmarshal(testState["route_state"], &routeState)
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := routeState["routes"].([]any)
|
||||
route1 := routes[0].(map[string]any)
|
||||
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
|
||||
assert.Contains(t, route1["network"], "/24")
|
||||
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
|
||||
domains := route1["domains"].([]any)
|
||||
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
|
||||
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
|
||||
|
||||
// Check map keys are anonymized
|
||||
refCountMap := routeState["refCountMap"].(map[string]any)
|
||||
hasPublicIPKey := false
|
||||
hasIPv6Key := false
|
||||
hasPrivateIPKey := false
|
||||
for key := range refCountMap {
|
||||
if strings.Contains(key, "203.0.113.1") {
|
||||
hasPublicIPKey = true
|
||||
}
|
||||
if strings.Contains(key, "2001:db8::1") {
|
||||
hasIPv6Key = true
|
||||
}
|
||||
if key == "10.0.0.1/32" {
|
||||
hasPrivateIPKey = true
|
||||
}
|
||||
}
|
||||
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
|
||||
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
|
||||
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
|
||||
}
|
||||
|
||||
func mustMarshal(v any) json.RawMessage {
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func TestAnonymizeNetworkMap(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
Address: "203.0.113.5",
|
||||
Dns: "1.2.3.4",
|
||||
Fqdn: "peer1.corp.example.com",
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
|
||||
},
|
||||
},
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{
|
||||
AllowedIps: []string{
|
||||
"203.0.113.1/32",
|
||||
"2001:db8:1234::1/128",
|
||||
"192.168.1.1/32",
|
||||
"100.64.0.1/32",
|
||||
"10.0.0.1/32",
|
||||
},
|
||||
Fqdn: "peer2.corp.example.com",
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
|
||||
},
|
||||
},
|
||||
},
|
||||
Routes: []*mgmProto.Route{
|
||||
{
|
||||
Network: "197.51.100.0/24",
|
||||
Domains: []string{"prod.example.com", "staging.example.com"},
|
||||
NetID: "net-123abc",
|
||||
},
|
||||
},
|
||||
DNSConfig: &mgmProto.DNSConfig{
|
||||
NameServerGroups: []*mgmProto.NameServerGroup{
|
||||
{
|
||||
NameServers: []*mgmProto.NameServer{
|
||||
{IP: "8.8.8.8"},
|
||||
{IP: "1.1.1.1"},
|
||||
{IP: "203.0.113.53"},
|
||||
},
|
||||
Domains: []string{"example.com", "internal.example.com"},
|
||||
},
|
||||
},
|
||||
CustomZones: []*mgmProto.CustomZone{
|
||||
{
|
||||
Domain: "custom.example.com",
|
||||
Records: []*mgmProto.SimpleRecord{
|
||||
{
|
||||
Name: "www.custom.example.com",
|
||||
Type: 1,
|
||||
RData: "203.0.113.10",
|
||||
},
|
||||
{
|
||||
Name: "internal.custom.example.com",
|
||||
Type: 1,
|
||||
RData: "192.168.1.10",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create anonymizer with test addresses
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
// Anonymize the network map
|
||||
err := anonymizeNetworkMap(networkMap, anonymizer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test PeerConfig anonymization
|
||||
peerCfg := networkMap.PeerConfig
|
||||
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
|
||||
|
||||
// Verify DNS and FQDN are properly anonymized
|
||||
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
|
||||
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
|
||||
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
|
||||
|
||||
// Verify SSH key is replaced
|
||||
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
|
||||
|
||||
// Test RemotePeers anonymization
|
||||
remotePeer := networkMap.RemotePeers[0]
|
||||
|
||||
// Verify FQDN is anonymized
|
||||
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
|
||||
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
|
||||
|
||||
// Check that public IPs are anonymized but private IPs are preserved
|
||||
for _, allowedIP := range remotePeer.AllowedIps {
|
||||
ip, _, err := net.ParseCIDR(allowedIP)
|
||||
require.NoError(t, err)
|
||||
|
||||
if ip.IsPrivate() || isInCGNATRange(ip) {
|
||||
require.Contains(t, []string{
|
||||
"192.168.1.1/32",
|
||||
"100.64.0.1/32",
|
||||
"10.0.0.1/32",
|
||||
}, allowedIP)
|
||||
} else {
|
||||
require.NotContains(t, []string{
|
||||
"203.0.113.1/32",
|
||||
"2001:db8:1234::1/128",
|
||||
}, allowedIP)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Routes anonymization
|
||||
route := networkMap.Routes[0]
|
||||
require.NotEqual(t, "197.51.100.0/24", route.Network)
|
||||
for _, domain := range route.Domains {
|
||||
require.True(t, strings.HasSuffix(domain, ".domain"))
|
||||
require.NotContains(t, domain, "example.com")
|
||||
}
|
||||
|
||||
// Test DNS config anonymization
|
||||
dnsConfig := networkMap.DNSConfig
|
||||
nameServerGroup := dnsConfig.NameServerGroups[0]
|
||||
|
||||
// Verify well-known DNS servers are preserved
|
||||
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
|
||||
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
|
||||
|
||||
// Verify public DNS server is anonymized
|
||||
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
|
||||
|
||||
// Verify domains are anonymized
|
||||
for _, domain := range nameServerGroup.Domains {
|
||||
require.True(t, strings.HasSuffix(domain, ".domain"))
|
||||
require.NotContains(t, domain, "example.com")
|
||||
}
|
||||
|
||||
// Test CustomZones anonymization
|
||||
customZone := dnsConfig.CustomZones[0]
|
||||
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
|
||||
require.NotContains(t, customZone.Domain, "example.com")
|
||||
|
||||
// Verify records are properly anonymized
|
||||
for _, record := range customZone.Records {
|
||||
require.True(t, strings.HasSuffix(record.Name, ".domain"))
|
||||
require.NotContains(t, record.Name, "example.com")
|
||||
|
||||
ip := net.ParseIP(record.RData)
|
||||
if ip != nil {
|
||||
if !ip.IsPrivate() {
|
||||
require.NotEqual(t, "203.0.113.10", record.RData)
|
||||
} else {
|
||||
require.Equal(t, "192.168.1.10", record.RData)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if IP is in CGNAT range
|
||||
func isInCGNATRange(ip net.IP) bool {
|
||||
cgnat := net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
return cgnat.Contains(ip)
|
||||
}
|
||||
|
||||
func TestAnonymizeFirewallRules(t *testing.T) {
|
||||
// TODO: Add ipv6
|
||||
|
||||
// Example iptables-save output
|
||||
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||
*filter
|
||||
:INPUT ACCEPT [0:0]
|
||||
:FORWARD ACCEPT [0:0]
|
||||
:OUTPUT ACCEPT [0:0]
|
||||
-A INPUT -s 192.168.1.0/24 -j ACCEPT
|
||||
-A INPUT -s 44.192.140.1/32 -j DROP
|
||||
-A FORWARD -s 10.0.0.0/8 -j DROP
|
||||
-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT
|
||||
COMMIT
|
||||
|
||||
*nat
|
||||
:PREROUTING ACCEPT [0:0]
|
||||
:INPUT ACCEPT [0:0]
|
||||
:OUTPUT ACCEPT [0:0]
|
||||
:POSTROUTING ACCEPT [0:0]
|
||||
-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE
|
||||
-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80
|
||||
COMMIT`
|
||||
|
||||
// Example iptables -v -n -L output
|
||||
iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
pkts bytes target prot opt in out source destination
|
||||
0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0
|
||||
100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0
|
||||
|
||||
Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
|
||||
pkts bytes target prot opt in out source destination
|
||||
0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0
|
||||
25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24
|
||||
|
||||
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||
pkts bytes target prot opt in out source destination`
|
||||
|
||||
// Example nftables output
|
||||
nftablesRules := `table inet filter {
|
||||
chain input {
|
||||
type filter hook input priority filter; policy accept;
|
||||
ip saddr 192.168.1.1 accept
|
||||
ip saddr 44.192.140.1 drop
|
||||
}
|
||||
chain forward {
|
||||
type filter hook forward priority filter; policy accept;
|
||||
ip saddr 10.0.0.0/8 drop
|
||||
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
||||
}
|
||||
}`
|
||||
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
|
||||
// Test iptables-save anonymization
|
||||
anonIptablesSave := anonymizer.AnonymizeString(iptablesSave)
|
||||
|
||||
// Private IP addresses should remain unchanged
|
||||
assert.Contains(t, anonIptablesSave, "192.168.1.0/24")
|
||||
assert.Contains(t, anonIptablesSave, "10.0.0.0/8")
|
||||
assert.Contains(t, anonIptablesSave, "192.168.100.0/24")
|
||||
assert.Contains(t, anonIptablesSave, "192.168.1.10")
|
||||
|
||||
// Public IP addresses should be anonymized to the default range
|
||||
assert.NotContains(t, anonIptablesSave, "44.192.140.1")
|
||||
assert.NotContains(t, anonIptablesSave, "44.192.140.0/24")
|
||||
assert.NotContains(t, anonIptablesSave, "52.84.12.34")
|
||||
assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range
|
||||
|
||||
// Structure should be preserved
|
||||
assert.Contains(t, anonIptablesSave, "*filter")
|
||||
assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]")
|
||||
assert.Contains(t, anonIptablesSave, "COMMIT")
|
||||
assert.Contains(t, anonIptablesSave, "-j MASQUERADE")
|
||||
assert.Contains(t, anonIptablesSave, "--dport 80")
|
||||
|
||||
// Test iptables verbose output anonymization
|
||||
anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose)
|
||||
|
||||
// Private IP addresses should remain unchanged
|
||||
assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24")
|
||||
assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8")
|
||||
|
||||
// Public IP addresses should be anonymized to the default range
|
||||
assert.NotContains(t, anonIptablesVerbose, "44.192.140.1")
|
||||
assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24")
|
||||
assert.NotContains(t, anonIptablesVerbose, "52.84.12.34")
|
||||
assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range
|
||||
|
||||
// Structure and counters should be preserved
|
||||
assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)")
|
||||
assert.Contains(t, anonIptablesVerbose, "100 1024 DROP")
|
||||
assert.Contains(t, anonIptablesVerbose, "pkts bytes target")
|
||||
|
||||
// Test nftables anonymization
|
||||
anonNftables := anonymizer.AnonymizeString(nftablesRules)
|
||||
|
||||
// Private IP addresses should remain unchanged
|
||||
assert.Contains(t, anonNftables, "192.168.1.1")
|
||||
assert.Contains(t, anonNftables, "10.0.0.0/8")
|
||||
|
||||
// Public IP addresses should be anonymized to the default range
|
||||
assert.NotContains(t, anonNftables, "44.192.140.1")
|
||||
assert.NotContains(t, anonNftables, "44.192.140.0/24")
|
||||
assert.NotContains(t, anonNftables, "52.84.12.34")
|
||||
assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range
|
||||
|
||||
// Structure should be preserved
|
||||
assert.Contains(t, anonNftables, "table inet filter {")
|
||||
assert.Contains(t, anonNftables, "chain input {")
|
||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||
}
|
||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
||||
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
return listOfDomains
|
||||
}
|
||||
|
||||
@@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
}
|
||||
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
c.removeEntry(origPattern, priority)
|
||||
|
||||
// Check if handler implements SubdomainMatcher interface
|
||||
matchSubdomains := false
|
||||
@@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
c.removeEntry(pattern, priority)
|
||||
}
|
||||
|
||||
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||
return
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
for _, entry := range c.handlers {
|
||||
if strings.EqualFold(entry.Pattern, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
|
||||
@@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||
for _, handler := range handlers {
|
||||
handler.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verify handler exists check
|
||||
for priority, shouldExist := range tt.expectedCalls {
|
||||
if shouldExist {
|
||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(testQuery, dns.TypeA)
|
||||
|
||||
// Keep track of mocks for the final assertion in Step 4
|
||||
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||
|
||||
// Add handlers in mixed order
|
||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||
|
||||
// Test 1: Initial state with all three handlers
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Test 1: Initial state
|
||||
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Highest priority handler (routeHandler) should be called
|
||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w1, r)
|
||||
routeHandler.AssertExpectations(t)
|
||||
|
||||
routeHandler.ExpectedCalls = nil
|
||||
routeHandler.Calls = nil
|
||||
matchHandler.ExpectedCalls = nil
|
||||
matchHandler.Calls = nil
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 2: Remove highest priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now middle priority handler (matchHandler) should be called
|
||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w2, r)
|
||||
matchHandler.AssertExpectations(t)
|
||||
|
||||
matchHandler.ExpectedCalls = nil
|
||||
matchHandler.Calls = nil
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 3: Remove middle priority handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||
assert.True(t, chain.HasHandlers(testDomain))
|
||||
|
||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
// Now lowest priority handler (defaultHandler) should be called
|
||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||
|
||||
chain.ServeDNS(w, r)
|
||||
chain.ServeDNS(w3, r)
|
||||
defaultHandler.AssertExpectations(t)
|
||||
|
||||
defaultHandler.ExpectedCalls = nil
|
||||
defaultHandler.Calls = nil
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||
|
||||
for _, m := range mocks {
|
||||
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
@@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addPattern string
|
||||
removePattern string
|
||||
queryPattern string
|
||||
shouldBeRemoved bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "exact same pattern",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "example.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing with identical patterns",
|
||||
},
|
||||
{
|
||||
name: "case difference",
|
||||
addPattern: "Example.Com.",
|
||||
removePattern: "EXAMPLE.COM.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with mixed case, removing with uppercase",
|
||||
},
|
||||
{
|
||||
name: "reversed case difference",
|
||||
addPattern: "EXAMPLE.ORG.",
|
||||
removePattern: "example.org.",
|
||||
queryPattern: "example.org.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with uppercase, removing with lowercase",
|
||||
},
|
||||
{
|
||||
name: "add wildcard, remove wildcard",
|
||||
addPattern: "*.example.com.",
|
||||
removePattern: "*.example.com.",
|
||||
queryPattern: "sub.example.com.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing with identical wildcard patterns",
|
||||
},
|
||||
{
|
||||
name: "add wildcard, remove transformed pattern",
|
||||
addPattern: "*.example.net.",
|
||||
removePattern: "example.net.",
|
||||
queryPattern: "sub.example.net.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||
},
|
||||
{
|
||||
name: "add transformed pattern, remove wildcard",
|
||||
addPattern: "example.io.",
|
||||
removePattern: "*.example.io.",
|
||||
queryPattern: "example.io.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||
},
|
||||
{
|
||||
name: "trailing dot difference",
|
||||
addPattern: "example.dev",
|
||||
removePattern: "example.dev.",
|
||||
queryPattern: "example.dev.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding without trailing dot, removing with trailing dot",
|
||||
},
|
||||
{
|
||||
name: "reversed trailing dot difference",
|
||||
addPattern: "example.app.",
|
||||
removePattern: "example.app",
|
||||
queryPattern: "example.app.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding with trailing dot, removing without trailing dot",
|
||||
},
|
||||
{
|
||||
name: "mixed case and wildcard",
|
||||
addPattern: "*.Example.Site.",
|
||||
removePattern: "*.EXAMPLE.SITE.",
|
||||
queryPattern: "sub.example.site.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||
},
|
||||
{
|
||||
name: "root zone",
|
||||
addPattern: ".",
|
||||
removePattern: ".",
|
||||
queryPattern: "random.domain.",
|
||||
shouldBeRemoved: true,
|
||||
description: "Adding and removing root zone",
|
||||
},
|
||||
{
|
||||
name: "wrong domain",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "different.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding one domain, trying to remove a different domain",
|
||||
},
|
||||
{
|
||||
name: "subdomain mismatch",
|
||||
addPattern: "sub.example.com.",
|
||||
removePattern: "example.com.",
|
||||
queryPattern: "sub.example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding subdomain, trying to remove parent domain",
|
||||
},
|
||||
{
|
||||
name: "parent domain mismatch",
|
||||
addPattern: "example.com.",
|
||||
removePattern: "sub.example.com.",
|
||||
queryPattern: "example.com.",
|
||||
shouldBeRemoved: false,
|
||||
description: "Adding parent domain, trying to remove subdomain",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
handler := &nbdns.MockHandler{}
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||
|
||||
// First verify no handler is called before adding any
|
||||
chain.ServeDNS(w, r)
|
||||
handler.AssertNotCalled(t, "ServeDNS")
|
||||
|
||||
// Add handler
|
||||
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||
|
||||
// Verify handler is called after adding
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
chain.ServeDNS(w, r)
|
||||
handler.AssertExpectations(t)
|
||||
|
||||
// Reset mock for the next test
|
||||
handler.ExpectedCalls = nil
|
||||
|
||||
// Remove handler
|
||||
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||
|
||||
// Set up expectations based on whether removal should succeed
|
||||
if !tt.shouldBeRemoved {
|
||||
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||
}
|
||||
|
||||
// Test if handler is still called after removal attempt
|
||||
chain.ServeDNS(w, r)
|
||||
|
||||
if tt.shouldBeRemoved {
|
||||
handler.AssertNotCalled(t, "ServeDNS",
|
||||
"Handler should not be called after successful removal with pattern %q",
|
||||
tt.removePattern)
|
||||
} else {
|
||||
handler.AssertExpectations(t)
|
||||
handler.ExpectedCalls = nil
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
@@ -12,8 +14,8 @@ import (
|
||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa"
|
||||
ipv6ReverseZone = ".ip6.arpa"
|
||||
ipv4ReverseZone = ".in-addr.arpa."
|
||||
ipv6ReverseZone = ".ip6.arpa."
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
@@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
|
||||
for _, domain := range nsConfig.Domains {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(domain, "."),
|
||||
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||
})
|
||||
}
|
||||
@@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
MatchOnly: matchOnly,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, dConf.Domain)
|
||||
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
|
||||
@@ -17,9 +17,12 @@ import (
|
||||
|
||||
var (
|
||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||
|
||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||
|
||||
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
continue
|
||||
}
|
||||
if !dConf.MatchOnly {
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
@@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string {
|
||||
return "registry"
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) flushDNSCache() error {
|
||||
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
|
||||
}
|
||||
}()
|
||||
|
||||
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||
if ret == 0 {
|
||||
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||
return fmt.Errorf("update search domains: %w", err)
|
||||
@@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -71,6 +71,12 @@ func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
||||
|
||||
value, found := d.records.Load(key)
|
||||
if !found {
|
||||
// alternatively check if we have a cname
|
||||
if question.Qtype != dns.TypeCNAME {
|
||||
r.Question[0].Qtype = dns.TypeCNAME
|
||||
return d.lookupRecords(r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// MockServer is the mock instance of a dns server
|
||||
@@ -13,17 +14,17 @@ type MockServer struct {
|
||||
InitializeFunc func() error
|
||||
StopFunc func()
|
||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
||||
DeregisterHandlerFunc func([]string, int)
|
||||
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||
DeregisterHandlerFunc func(domain.List, int)
|
||||
}
|
||||
|
||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
if m.RegisterHandlerFunc != nil {
|
||||
m.RegisterHandlerFunc(domains, handler, priority)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
||||
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
if m.DeregisterHandlerFunc != nil {
|
||||
m.DeregisterHandlerFunc(domains, priority)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
||||
continue
|
||||
}
|
||||
if dConf.MatchOnly {
|
||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
||||
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||
continue
|
||||
}
|
||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
||||
searchDomains = append(searchDomains, dConf.Domain)
|
||||
}
|
||||
|
||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
@@ -32,8 +35,8 @@ type IosDnsManager interface {
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains []string, priority int)
|
||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||
DeregisterHandler(domains domain.List, priority int)
|
||||
Initialize() error
|
||||
Stop()
|
||||
DnsIP() string
|
||||
@@ -65,6 +68,7 @@ type DefaultServer struct {
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
@@ -164,13 +168,15 @@ func newDefaultServer(
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
handlerChain: handlerChain,
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
@@ -181,14 +187,26 @@ func newDefaultServer(
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
}
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
||||
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||
// Any previously registered handler for the same domain and priority will be replaced.
|
||||
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.registerHandler(domains, handler, priority)
|
||||
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||
|
||||
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||
for _, domain := range domains {
|
||||
// convert to zone with simple ref counter
|
||||
s.extraDomains[toZone(domain)]++
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||
@@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
||||
continue
|
||||
}
|
||||
s.handlerChain.AddHandler(domain, handler, priority)
|
||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
||||
// DeregisterHandler deregisters the handler for the given domains with the given priority.
|
||||
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.deregisterHandler(domains, priority)
|
||||
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||
for _, domain := range domains {
|
||||
zone := toZone(domain)
|
||||
s.extraDomains[zone]--
|
||||
if s.extraDomains[zone] <= 0 {
|
||||
delete(s.extraDomains, zone)
|
||||
}
|
||||
}
|
||||
s.applyHostConfig()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
@@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||
}
|
||||
|
||||
s.handlerChain.RemoveHandler(domain, priority)
|
||||
|
||||
// Only deregister from service if no handlers remain
|
||||
if !s.handlerChain.HasHandlers(domain) {
|
||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
|
||||
s.service.Stop()
|
||||
|
||||
maps.Clear(s.extraDomains)
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
@@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be Disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
if err := s.service.Listen(); err != nil {
|
||||
log.Errorf("failed to start DNS service: %v", err)
|
||||
}
|
||||
} else if !s.permanent {
|
||||
s.service.Stop()
|
||||
}
|
||||
@@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.RouteAll = false
|
||||
s.currentConfig.RouteAll = false
|
||||
}
|
||||
|
||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
||||
log.Error(err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
@@ -441,6 +462,43 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) applyHostConfig() {
|
||||
if s.hostManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// prevent reapplying config if we're shutting down
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
config := s.currentConfig
|
||||
|
||||
existingDomains := make(map[string]struct{})
|
||||
for _, d := range config.Domains {
|
||||
existingDomains[d.Domain] = struct{}{}
|
||||
}
|
||||
|
||||
// add extra domains only if they're not already in the config
|
||||
for domain := range s.extraDomains {
|
||||
domainStr := domain.PunycodeString()
|
||||
|
||||
if _, exists := existingDomains[domainStr]; !exists {
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: domainStr,
|
||||
MatchOnly: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
s.handleErrNoGroupaAll(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||
return
|
||||
@@ -690,10 +748,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
@@ -728,12 +783,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||
}
|
||||
|
||||
if s.hostManager != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
s.handleErrNoGroupaAll(err)
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
s.applyHostConfig()
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
}
|
||||
@@ -836,3 +886,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toZone(d domain.Domain) domain.Domain {
|
||||
return domain.Domain(
|
||||
nbdns.NormalizeZone(
|
||||
dns.Fqdn(
|
||||
strings.ToLower(d.PunycodeString()),
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -29,16 +29,17 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
return "utun2301"
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() wgaddr.Address {
|
||||
@@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialConfig nbdns.Config
|
||||
registerDomains []domain.List
|
||||
deregisterDomains []domain.List
|
||||
finalConfig nbdns.Config
|
||||
expectedDomains []string
|
||||
expectedMatchOnly []string
|
||||
applyHostConfigCall int
|
||||
}{
|
||||
{
|
||||
name: "Register domains before config update",
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register domains after config update",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra1.example.com.",
|
||||
"extra2.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register overlapping domains",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "overlap.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"overlap.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register and deregister domains",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra1.example.com", "extra2.example.com"},
|
||||
{"extra3.example.com", "extra4.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"extra1.example.com", "extra3.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra2.example.com.",
|
||||
"extra4.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
},
|
||||
{
|
||||
name: "Register domains with ref counter",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "duplicate.example.com"},
|
||||
{"other.example.com", "duplicate.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"duplicate.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
},
|
||||
{
|
||||
name: "Config update with new domains after registration",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "duplicate.example.com"},
|
||||
},
|
||||
finalConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "newconfig.example.com"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"config.example.com.",
|
||||
"newconfig.example.com.",
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Deregister domain that is part of customZones",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "protected.example.com"},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "protected.example.com"},
|
||||
},
|
||||
deregisterDomains: []domain.List{
|
||||
{"protected.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"extra.example.com.",
|
||||
"config.example.com.",
|
||||
"protected.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Register domain that is part of nameserver group",
|
||||
initialConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
registerDomains: []domain.List{
|
||||
{"extra.example.com", "overlap.ns.example.com"},
|
||||
},
|
||||
expectedDomains: []string{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
expectedMatchOnly: []string{
|
||||
"ns.example.com.",
|
||||
"overlap.ns.example.com.",
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var capturedConfigs []HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfigs = append(capturedConfigs, config)
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
wgInterface: &mocWGIface{},
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
// Apply initial configuration
|
||||
if tt.initialConfig.ServiceEnable {
|
||||
err := server.applyConfiguration(tt.initialConfig)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Register domains
|
||||
for _, domains := range tt.registerDomains {
|
||||
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
||||
}
|
||||
|
||||
// Deregister domains if specified
|
||||
for _, domains := range tt.deregisterDomains {
|
||||
server.DeregisterHandler(domains, PriorityDefault)
|
||||
}
|
||||
|
||||
// Apply final configuration if specified
|
||||
if tt.finalConfig.ServiceEnable {
|
||||
err := server.applyConfiguration(tt.finalConfig)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify number of calls
|
||||
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
||||
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
||||
|
||||
// Get the last applied config
|
||||
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||
|
||||
// Check all expected domains are present
|
||||
domainMap := make(map[string]bool)
|
||||
matchOnlyMap := make(map[string]bool)
|
||||
|
||||
for _, d := range lastConfig.Domains {
|
||||
domainMap[d.Domain] = true
|
||||
if d.MatchOnly {
|
||||
matchOnlyMap[d.Domain] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Verify expected domains
|
||||
for _, d := range tt.expectedDomains {
|
||||
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
||||
}
|
||||
|
||||
// Verify match-only domains
|
||||
for _, d := range tt.expectedMatchOnly {
|
||||
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
||||
}
|
||||
|
||||
// Verify no unexpected domains
|
||||
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
||||
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraDomainsRefCounting(t *testing.T) {
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
// Register domains from different handlers with same domain
|
||||
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 2
|
||||
zoneKey := toZone("shared.example.com")
|
||||
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||
|
||||
// Deregister one handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||
|
||||
// Verify refcount is 1
|
||||
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||
|
||||
// Deregister the other handler
|
||||
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
||||
|
||||
// Verify domain is removed
|
||||
_, exists := server.extraDomains[zoneKey]
|
||||
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
||||
}
|
||||
|
||||
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||
var capturedConfig HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfig = config
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
|
||||
initialConfig := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
}
|
||||
err := server.applyConfiguration(initialConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
|
||||
// Now apply a new configuration with overlapping domain
|
||||
updatedConfig := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
{Domain: "extra.example.com"},
|
||||
},
|
||||
}
|
||||
err = server.applyConfiguration(updatedConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify both domains are in config, but no duplicates
|
||||
domains = []string{}
|
||||
matchOnlyCount := 0
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
if d.MatchOnly {
|
||||
matchOnlyCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Contains(t, domains, "config.example.com.")
|
||||
assert.Contains(t, domains, "extra.example.com.")
|
||||
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||
|
||||
// Extra domain should no longer be marked as match-only when in config
|
||||
matchOnlyDomain := ""
|
||||
for _, d := range capturedConfig.Domains {
|
||||
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||
matchOnlyDomain = d.Domain
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
||||
}
|
||||
|
||||
func TestDomainCaseHandling(t *testing.T) {
|
||||
var capturedConfig HostDNSConfig
|
||||
mockHostConfig := &mockHostConfigurator{
|
||||
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
capturedConfig = config
|
||||
return nil
|
||||
},
|
||||
restoreHostDNSFunc: func() error {
|
||||
return nil
|
||||
},
|
||||
supportCustomPortFunc: func() bool {
|
||||
return true
|
||||
},
|
||||
stringFunc: func() string {
|
||||
return "mock"
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc := &mockService{}
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: mockHostConfig,
|
||||
localResolver: &localResolver{},
|
||||
service: mockSvc,
|
||||
statusRecorder: peer.NewRecorder("test"),
|
||||
extraDomains: make(map[domain.Domain]int),
|
||||
}
|
||||
|
||||
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||
|
||||
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{Domain: "config.example.com"},
|
||||
},
|
||||
}
|
||||
err := server.applyConfiguration(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var domains []string
|
||||
for _, d := range capturedConfig.Domains {
|
||||
domains = append(domains, d.Domain)
|
||||
}
|
||||
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
@@ -38,7 +37,6 @@ const (
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
@@ -112,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
continue
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: dns.Fqdn(dConf.Domain),
|
||||
Domain: dConf.Domain,
|
||||
MatchOnly: dConf.MatchOnly,
|
||||
})
|
||||
|
||||
@@ -124,18 +122,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
}
|
||||
|
||||
if config.RouteAll {
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
||||
return fmt.Errorf("set link as default dns router: %w", err)
|
||||
}
|
||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||
Domain: nbdns.RootZone,
|
||||
MatchOnly: true,
|
||||
})
|
||||
s.routingAll = true
|
||||
} else if s.routingAll {
|
||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
|
||||
return fmt.Errorf("remove link as default dns router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
state := &ShutdownState{
|
||||
@@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||
}
|
||||
return s.flushCaches()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
@@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||
}
|
||||
|
||||
return s.flushCaches()
|
||||
if err := s.flushDNSCache(); err != nil {
|
||||
log.Errorf("failed to flush DNS cache: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
||||
func (s *systemdDbusConfigurator) flushDNSCache() error {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||
|
||||
@@ -18,14 +18,16 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
UpstreamTimeout = 15 * time.Second
|
||||
|
||||
failsTillDeact = int32(5)
|
||||
reactivatePeriod = 30 * time.Second
|
||||
upstreamTimeout = 15 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
@@ -66,7 +68,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
statusRecorder: statusRecorder,
|
||||
@@ -106,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}()
|
||||
|
||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||
if r.Extra == nil {
|
||||
r.SetEdns0(4096, false)
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
@@ -336,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||
return err
|
||||
}
|
||||
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// MTU - ip + udp headers
|
||||
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
|
||||
client.UDPSize = iface.DefaultMTU - (60 + 8)
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
t time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||
}
|
||||
|
||||
if rm == nil || !rm.MsgHdr.Truncated {
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||
|
||||
client.Net = "tcp"
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
|
||||
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
||||
|
||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
timeout := upstreamTimeout
|
||||
timeout := UpstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
|
||||
@@ -34,6 +34,5 @@ func newUpstreamResolver(
|
||||
}
|
||||
|
||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||
upstreamExchangeClient := &dns.Client{}
|
||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
||||
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||
}
|
||||
|
||||
timeout := upstreamTimeout
|
||||
timeout := UpstreamTimeout
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
timeout = time.Until(deadline)
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return client.Exchange(r, upstream)
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
name: "Should Resolve A Record",
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
timeout: upstreamTimeout,
|
||||
timeout: UpstreamTimeout,
|
||||
expectedAnswer: "1.1.1.1",
|
||||
},
|
||||
{
|
||||
@@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||
cancelCTX: true,
|
||||
timeout: upstreamTimeout,
|
||||
timeout: UpstreamTimeout,
|
||||
responseShouldBeNil: true,
|
||||
},
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
r: new(dns.Msg),
|
||||
rtt: time.Millisecond,
|
||||
},
|
||||
upstreamTimeout: upstreamTimeout,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
failsTillDeact: failsTillDeact,
|
||||
}
|
||||
|
||||
@@ -3,34 +3,53 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||
const upstreamTimeout = 15 * time.Second
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
domains []string
|
||||
listenAddress string
|
||||
ttl uint32
|
||||
statusRecorder *peer.Status
|
||||
|
||||
dnsServer *dns.Server
|
||||
mux *dns.ServeMux
|
||||
|
||||
mutex sync.RWMutex
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewall.Manager
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
listenAddress: listenAddress,
|
||||
ttl: ttl,
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(domains []string) error {
|
||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
@@ -42,23 +61,35 @@ func (f *DNSForwarder) Listen(domains []string) error {
|
||||
f.dnsServer = dnsServer
|
||||
f.mux = mux
|
||||
|
||||
f.UpdateDomains(domains)
|
||||
f.UpdateDomains(entries)
|
||||
|
||||
return dnsServer.ListenAndServe()
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
||||
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
for _, d := range f.domains {
|
||||
f.mux.HandleRemove(d)
|
||||
if f.mux == nil {
|
||||
log.Debug("DNS mux is nil, skipping domain update")
|
||||
f.fwdEntries = entries
|
||||
return
|
||||
}
|
||||
|
||||
newDomains := filterDomains(domains)
|
||||
oldDomains := filterDomains(f.fwdEntries)
|
||||
|
||||
for _, d := range oldDomains {
|
||||
f.mux.HandleRemove(d.PunycodeString())
|
||||
}
|
||||
|
||||
newDomains := filterDomains(entries)
|
||||
for _, d := range newDomains {
|
||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
|
||||
}
|
||||
f.domains = newDomains
|
||||
|
||||
f.fwdEntries = entries
|
||||
|
||||
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
@@ -72,48 +103,116 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
if len(query.Question) == 0 {
|
||||
return
|
||||
}
|
||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
|
||||
|
||||
question := query.Question[0]
|
||||
domain := question.Name
|
||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
||||
question.Name, question.Qtype, question.Qclass)
|
||||
|
||||
domain := strings.ToLower(question.Name)
|
||||
|
||||
resp := query.SetReply(query)
|
||||
var network string
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
network = "ip4"
|
||||
case dns.TypeAAAA:
|
||||
network = "ip6"
|
||||
default:
|
||||
// TODO: Handle other types
|
||||
|
||||
ips, err := net.LookupIP(domain)
|
||||
if err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
// Pass through NXDOMAIN
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
resp.Rcode = dns.RcodeNotImplemented
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||
defer cancel()
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
||||
if err != nil {
|
||||
f.handleDNSError(w, resp, domain, err)
|
||||
return
|
||||
}
|
||||
|
||||
f.updateInternalState(domain, ips)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
|
||||
var prefixes []netip.Prefix
|
||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
||||
if mostSpecificResId != "" {
|
||||
for _, ip := range ips {
|
||||
var prefix netip.Prefix
|
||||
if ip.Is4() {
|
||||
prefix = netip.PrefixFrom(ip, 32)
|
||||
} else {
|
||||
prefix = netip.PrefixFrom(ip, 128)
|
||||
}
|
||||
prefixes = append(prefixes, prefix)
|
||||
f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId)
|
||||
}
|
||||
}
|
||||
|
||||
if f.firewall != nil {
|
||||
f.updateFirewall(matchingEntries, prefixes)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) {
|
||||
var merr *multierror.Error
|
||||
for _, entry := range matchingEntries {
|
||||
if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err))
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
log.Errorf("failed to update firewall sets (%d/%d): %v",
|
||||
len(merr.Errors),
|
||||
len(matchingEntries),
|
||||
nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
// Pass through NXDOMAIN
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
||||
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
||||
for _, ip := range ips {
|
||||
var respRecord dns.RR
|
||||
if ip.To4() == nil {
|
||||
if ip.Is6() {
|
||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||
rr := dns.AAAA{
|
||||
AAAA: ip,
|
||||
AAAA: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
@@ -125,7 +224,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
} else {
|
||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||
rr := dns.A{
|
||||
A: ip,
|
||||
A: ip.AsSlice(),
|
||||
Hdr: dns.RR_Header{
|
||||
Name: domain,
|
||||
Rrtype: dns.TypeA,
|
||||
@@ -137,21 +236,55 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
||||
}
|
||||
resp.Answer = append(resp.Answer, respRecord)
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write DNS response: %v", err)
|
||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||
// It returns the most specific match and all matching resource IDs.
|
||||
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {
|
||||
var selectedResId route.ResID
|
||||
var bestScore int
|
||||
var matches []*ForwarderEntry
|
||||
|
||||
f.mutex.RLock()
|
||||
defer f.mutex.RUnlock()
|
||||
|
||||
for _, entry := range f.fwdEntries {
|
||||
var score int
|
||||
pattern := entry.Domain.PunycodeString()
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(pattern, "*."):
|
||||
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||
|
||||
if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) {
|
||||
score = len(baseDomain)
|
||||
matches = append(matches, entry)
|
||||
}
|
||||
case domain == pattern:
|
||||
score = math.MaxInt
|
||||
matches = append(matches, entry)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
selectedResId = entry.ResID
|
||||
}
|
||||
}
|
||||
|
||||
return selectedResId, matches
|
||||
}
|
||||
|
||||
// filterDomains returns a list of normalized domains
|
||||
func filterDomains(domains []string) []string {
|
||||
newDomains := make([]string, 0, len(domains))
|
||||
for _, d := range domains {
|
||||
if d == "" {
|
||||
func filterDomains(entries []*ForwarderEntry) domain.List {
|
||||
newDomains := make(domain.List, 0, len(entries))
|
||||
for _, d := range entries {
|
||||
if d.Domain == "" {
|
||||
log.Warn("empty domain in DNS forwarder")
|
||||
continue
|
||||
}
|
||||
newDomains = append(newDomains, nbdns.NormalizeZone(d))
|
||||
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
|
||||
}
|
||||
return newDomains
|
||||
}
|
||||
|
||||
103
client/internal/dnsfwd/forwarder_test.go
Normal file
103
client/internal/dnsfwd/forwarder_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func Test_getMatchingEntries(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storedMappings map[string]route.ResID // key: domain pattern, value: resId
|
||||
queryDomain string
|
||||
expectedResId route.ResID
|
||||
}{
|
||||
{
|
||||
name: "Empty map returns empty string",
|
||||
storedMappings: map[string]route.ResID{},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Exact match returns stored resId",
|
||||
storedMappings: map[string]route.ResID{"example.com": "res1"},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "res1",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern matches base domain",
|
||||
storedMappings: map[string]route.ResID{"*.example.com": "res2"},
|
||||
queryDomain: "example.com",
|
||||
expectedResId: "res2",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern matches subdomain",
|
||||
storedMappings: map[string]route.ResID{"*.example.com": "res3"},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "res3",
|
||||
},
|
||||
{
|
||||
name: "Wildcard pattern does not match different domain",
|
||||
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
|
||||
queryDomain: "foo.notexample.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Non-wildcard pattern does not match subdomain",
|
||||
storedMappings: map[string]route.ResID{"example.com": "res5"},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "",
|
||||
},
|
||||
{
|
||||
name: "Exact match over overlapping wildcard",
|
||||
storedMappings: map[string]route.ResID{
|
||||
"*.example.com": "resWildcard",
|
||||
"foo.example.com": "resExact",
|
||||
},
|
||||
queryDomain: "foo.example.com",
|
||||
expectedResId: "resExact",
|
||||
},
|
||||
{
|
||||
name: "Overlapping wildcards: Select more specific wildcard",
|
||||
storedMappings: map[string]route.ResID{
|
||||
"*.example.com": "resA",
|
||||
"*.sub.example.com": "resB",
|
||||
},
|
||||
queryDomain: "bar.sub.example.com",
|
||||
expectedResId: "resB",
|
||||
},
|
||||
{
|
||||
name: "Wildcard multi-level subdomain match",
|
||||
storedMappings: map[string]route.ResID{
|
||||
"*.example.com": "resMulti",
|
||||
},
|
||||
queryDomain: "a.b.example.com",
|
||||
expectedResId: "resMulti",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fwd := &DNSForwarder{}
|
||||
|
||||
var entries []*ForwarderEntry
|
||||
for domainPattern, resId := range tc.storedMappings {
|
||||
d, err := domain.FromString(domainPattern)
|
||||
require.NoError(t, err)
|
||||
entries = append(entries, &ForwarderEntry{
|
||||
Domain: d,
|
||||
ResID: resId,
|
||||
})
|
||||
}
|
||||
fwd.UpdateDomains(entries)
|
||||
|
||||
got, _ := fwd.getMatchingEntries(tc.queryDomain)
|
||||
assert.Equal(t, got, tc.expectedResId)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,9 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -18,20 +21,29 @@ const (
|
||||
dnsTTL = 60 //seconds
|
||||
)
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
type ForwarderEntry struct {
|
||||
Domain domain.Domain
|
||||
ResID route.ResID
|
||||
Set firewall.Set
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
|
||||
fwRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(domains []string) error {
|
||||
func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
log.Infof("starting DNS forwarder")
|
||||
if m.dnsForwarder != nil {
|
||||
return nil
|
||||
@@ -41,9 +53,9 @@ func (m *Manager) Start(domains []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(domains); err != nil {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||
}
|
||||
@@ -52,12 +64,12 @@ func (m *Manager) Start(domains []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateDomains(domains []string) {
|
||||
func (m *Manager) UpdateDomains(entries []*ForwarderEntry) {
|
||||
if m.dnsForwarder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.dnsForwarder.UpdateDomains(domains)
|
||||
m.dnsForwarder.UpdateDomains(entries)
|
||||
}
|
||||
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
@@ -78,34 +90,34 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
func (h *Manager) allowDNSFirewall() error {
|
||||
func (m *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []uint16{ListenPort},
|
||||
}
|
||||
|
||||
if h.firewall == nil {
|
||||
if m.firewall == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
}
|
||||
h.fwRules = dnsRules
|
||||
m.fwRules = dnsRules
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Manager) dropDNSFirewall() error {
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range h.fwRules {
|
||||
if err := h.firewall.DeletePeerRule(rule); err != nil {
|
||||
for _, rule := range m.fwRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
h.fwRules = nil
|
||||
m.fwRules = nil
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
@@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() {
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{v4},
|
||||
network,
|
||||
firewallManager.Network{Prefix: network},
|
||||
firewallManager.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
@@ -952,11 +952,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Apply ACLs in the beginning to avoid security leaks
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap)
|
||||
}
|
||||
|
||||
if e.firewall != nil {
|
||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||
@@ -965,16 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
|
||||
// DNS forwarder
|
||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
if e.acl != nil {
|
||||
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
|
||||
}
|
||||
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Ingress forward rules
|
||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
@@ -1079,21 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
||||
return routes
|
||||
}
|
||||
|
||||
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
|
||||
var dnsRoutes []string
|
||||
for _, protoRoute := range protoRoutes {
|
||||
if len(protoRoute.Domains) == 0 {
|
||||
func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry {
|
||||
var entries []*dnsfwd.ForwarderEntry
|
||||
for _, route := range routes {
|
||||
if len(route.Domains) == 0 {
|
||||
continue
|
||||
}
|
||||
if protoRoute.Peer == myPubKey {
|
||||
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||
if route.Peer == myPubKey {
|
||||
domainSet := firewallManager.NewDomainSet(route.Domains)
|
||||
for _, d := range route.Domains {
|
||||
entries = append(entries, &dnsfwd.ForwarderEntry{
|
||||
Domain: d,
|
||||
Set: domainSet,
|
||||
ResID: route.GetResourceID(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return dnsRoutes
|
||||
return entries
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||
@@ -1223,36 +1226,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
||||
PreSharedKey: e.config.PreSharedKey,
|
||||
}
|
||||
|
||||
if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
|
||||
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
|
||||
rk := []byte(wgConfig.RemoteKey)
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgConfig.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// randomize connection timeout
|
||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||
config := peer.ConnConfig{
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
||||
RosenpassAddr: e.getRosenpassAddr(),
|
||||
Key: pubKey,
|
||||
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
Timeout: timeout,
|
||||
WgConfig: wgConfig,
|
||||
LocalWgPort: e.config.WgPort,
|
||||
RosenpassConfig: peer.RosenpassConfig{
|
||||
PubKey: e.getRosenpassPubKey(),
|
||||
Addr: e.getRosenpassAddr(),
|
||||
PermissiveMode: e.config.RosenpassPermissive,
|
||||
},
|
||||
ICEConfig: icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
@@ -1760,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP {
|
||||
}
|
||||
|
||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
) {
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1771,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(domains) > 0 {
|
||||
log.Infof("enable domain router service for domains: %v", domains)
|
||||
if len(fwdEntries) > 0 {
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
} else {
|
||||
log.Infof("update domain router service for domains: %v", domains)
|
||||
e.dnsForwardMgr.UpdateDomains(domains)
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
|
||||
@@ -49,6 +49,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -1400,15 +1401,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||
t.Helper()
|
||||
|
||||
config := &server.Config{
|
||||
Stuns: []*server.Host{},
|
||||
TURNConfig: &server.TURNConfig{},
|
||||
Relay: &server.Relay{
|
||||
config := &types.Config{
|
||||
Stuns: []*types.Host{},
|
||||
TURNConfig: &types.TURNConfig{},
|
||||
Relay: &types.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &server.Host{
|
||||
Signal: &types.Host{
|
||||
Proto: "http",
|
||||
URI: "localhost:10000",
|
||||
},
|
||||
@@ -1446,7 +1447,9 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/ti-mo/netfilter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const defaultChannelSize = 100
|
||||
@@ -176,7 +177,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
srcIP := flow.TupleOrig.IP.SourceAddress
|
||||
dstIP := flow.TupleOrig.IP.DestinationAddress
|
||||
|
||||
if !c.relevantFlow(srcIP, dstIP) {
|
||||
if !c.relevantFlow(flow.Mark, srcIP, dstIP) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -193,7 +194,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
}
|
||||
|
||||
flowID := c.getFlowID(flow.ID)
|
||||
direction := c.inferDirection(srcIP, dstIP)
|
||||
direction := c.inferDirection(flow.Mark, srcIP, dstIP)
|
||||
|
||||
eventType := nftypes.TypeStart
|
||||
eventStr := "New"
|
||||
@@ -224,15 +225,14 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||
}
|
||||
|
||||
// relevantFlow checks if the flow is related to the specified interface
|
||||
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
|
||||
// TODO: filter traffic by interface
|
||||
|
||||
wgnet := c.iface.Address().Network
|
||||
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
|
||||
return false
|
||||
func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
|
||||
if nbnet.IsDataPlaneMark(mark) {
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
// fallback if mark rules are not in place
|
||||
wgnet := c.iface.Address().Network
|
||||
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
|
||||
}
|
||||
|
||||
// mapRxPackets maps packet counts to RX based on flow direction
|
||||
@@ -282,7 +282,15 @@ func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
|
||||
return uuid.NewSHA1(c.instanceID, buf[:])
|
||||
}
|
||||
|
||||
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
switch mark {
|
||||
case nbnet.DataPlaneMarkIn:
|
||||
return nftypes.Ingress
|
||||
case nbnet.DataPlaneMarkOut:
|
||||
return nftypes.Egress
|
||||
}
|
||||
|
||||
// fallback if marks are not set
|
||||
wgaddr := c.iface.Address().IP
|
||||
wgnetwork := c.iface.Address().Network
|
||||
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||
@@ -298,8 +306,6 @@ func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||
case wgnetwork.Contains(dst):
|
||||
// resource network -> netbird network
|
||||
return nftypes.Egress
|
||||
|
||||
// TODO: handle site2site traffic
|
||||
}
|
||||
|
||||
return nftypes.DirectionUnknown
|
||||
|
||||
@@ -19,11 +19,9 @@ import (
|
||||
type rcvChan chan *types.EventFields
|
||||
type Logger struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
enabled atomic.Bool
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancelReceiver context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
dnsCollection atomic.Bool
|
||||
@@ -31,12 +29,9 @@ type Logger struct {
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Logger{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
@@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
|
||||
}
|
||||
|
||||
l.mux.Lock()
|
||||
ctx, cancel := context.WithCancel(l.ctx)
|
||||
l.cancelReceiver = cancel
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
l.mux.Unlock()
|
||||
|
||||
c := make(rcvChan, 100)
|
||||
@@ -91,25 +86,25 @@ func (l *Logger) startReceiver() {
|
||||
Timestamp: time.Now().UTC(),
|
||||
}
|
||||
|
||||
var isExitNode bool
|
||||
if event.Direction == types.Ingress {
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
event.SourceResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
} else if event.Direction == types.Egress {
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
event.DestResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
var isSrcExitNode bool
|
||||
var isDestExitNode bool
|
||||
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
|
||||
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
|
||||
}
|
||||
|
||||
if l.shouldStore(eventFields, isExitNode) {
|
||||
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
|
||||
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
|
||||
}
|
||||
|
||||
if l.shouldStore(eventFields, isSrcExitNode || isDestExitNode) {
|
||||
l.Store.StoreEvent(&event)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Disable() {
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.Store.Close()
|
||||
}
|
||||
@@ -121,9 +116,9 @@ func (l *Logger) stop() {
|
||||
|
||||
l.enabled.Store(false)
|
||||
l.mux.Lock()
|
||||
if l.cancelReceiver != nil {
|
||||
l.cancelReceiver()
|
||||
l.cancelReceiver = nil
|
||||
if l.cancel != nil {
|
||||
l.cancel()
|
||||
l.cancel = nil
|
||||
}
|
||||
l.rcvChan.Store(nil)
|
||||
l.mux.Unlock()
|
||||
@@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
l.exitNodeCollection.Store(exitNodeCollection)
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.cancel()
|
||||
}
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(context.Background(), nil, net.IPNet{})
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
@@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
|
||||
}
|
||||
|
||||
// test disable
|
||||
logger.Disable()
|
||||
logger.Close()
|
||||
wait()
|
||||
logger.StoreEvent(event)
|
||||
wait()
|
||||
|
||||
@@ -27,18 +27,18 @@ type Manager struct {
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
ctx context.Context
|
||||
receiverClient *client.GRPCClient
|
||||
publicKey []byte
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(ctx, statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
|
||||
return &Manager{
|
||||
logger: flowLogger,
|
||||
conntrack: ct,
|
||||
ctx: ctx,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
}
|
||||
@@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
// first make sender ready so events don't pile up
|
||||
if m.needsNewClient(previous) {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
if err := m.resetClient(); err != nil {
|
||||
return fmt.Errorf("reset client: %w", err)
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs(flowClient)
|
||||
go m.startSender()
|
||||
}
|
||||
|
||||
m.logger.Enable()
|
||||
@@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetClient() error {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableFlow stops components for flow tracking
|
||||
func (m *Manager) disableFlow() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
m.logger.Close()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
|
||||
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
|
||||
|
||||
changed := previous != nil && update.Enabled != previous.Enabled
|
||||
if update.Enabled {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
if changed {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
}
|
||||
return m.enableFlow(previous)
|
||||
}
|
||||
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
err := m.disableFlow()
|
||||
if err != nil {
|
||||
log.Errorf("failed to disable netflow manager: %v", err)
|
||||
if changed {
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
}
|
||||
return err
|
||||
return m.disableFlow()
|
||||
}
|
||||
|
||||
// Close cleans up all resources
|
||||
@@ -151,17 +172,9 @@ func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Close()
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("failed to close receiver client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Close()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
@@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||
return m.logger
|
||||
}
|
||||
|
||||
func (m *Manager) startSender() {
|
||||
func (m *Manager) startSender(ctx context.Context) {
|
||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
events := m.logger.GetEvents()
|
||||
@@ -190,8 +203,8 @@ func (m *Manager) startSender() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
|
||||
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
id, err := uuid.FromBytes(ack.EventId)
|
||||
if err != nil {
|
||||
log.Warnf("failed to convert ack event id to uuid: %v", err)
|
||||
|
||||
200
client/internal/netflow/manager_test.go
Normal file
200
client/internal/netflow/manager_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type mockIFaceMapper struct {
|
||||
address wgaddr.Address
|
||||
isUserspaceBind bool
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
return m.address
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
return m.isUserspaceBind
|
||||
}
|
||||
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
statusRecorder := peer.NewRecorder("")
|
||||
|
||||
manager := NewManager(mockIFace, publicKey, statusRecorder)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.FlowConfig
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "disabled config",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled config with minimal valid settings",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: true,
|
||||
URL: "https://example.com",
|
||||
TokenPayload: "test-payload",
|
||||
TokenSignature: "test-signature",
|
||||
Interval: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.Update(tc.config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tc.config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, manager.flowConfig)
|
||||
|
||||
if tc.config.Enabled {
|
||||
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
|
||||
}
|
||||
|
||||
if tc.config.URL != "" {
|
||||
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
|
||||
}
|
||||
|
||||
if tc.config.TokenPayload != "" {
|
||||
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
manager := NewManager(mockIFace, publicKey, nil)
|
||||
|
||||
// First update with tokens
|
||||
initialConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
TokenPayload: "initial-payload",
|
||||
TokenSignature: "initial-signature",
|
||||
}
|
||||
|
||||
err := manager.Update(initialConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second update without tokens should preserve them
|
||||
updatedConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
URL: "https://example.com",
|
||||
}
|
||||
|
||||
err = manager.Update(updatedConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens were preserved
|
||||
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
|
||||
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
|
||||
}
|
||||
|
||||
func TestManager_NeedsNewClient(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
previous *types.FlowConfig
|
||||
current *types.FlowConfig
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil previous config",
|
||||
previous: nil,
|
||||
current: &types.FlowConfig{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "previous disabled",
|
||||
previous: &types.FlowConfig{Enabled: false},
|
||||
current: &types.FlowConfig{Enabled: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different URL",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenPayload",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenSignature",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same config",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "only interval changed",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager.flowConfig = tc.current
|
||||
result := manager.needsNewClient(tc.previous)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
const ZoneID = 0x1BD0
|
||||
|
||||
type Protocol uint8
|
||||
|
||||
const (
|
||||
@@ -120,9 +122,6 @@ type FlowLogger interface {
|
||||
Close()
|
||||
// Enable enables the flow logger receiver
|
||||
Enable()
|
||||
// Disable disables the flow logger receiver
|
||||
Disable()
|
||||
|
||||
// UpdateConfig updates the flow manager configuration
|
||||
UpdateConfig(dnsCollection, exitNodeCollection bool)
|
||||
}
|
||||
|
||||
@@ -60,6 +60,15 @@ type WgConfig struct {
|
||||
PreSharedKey *wgtypes.Key
|
||||
}
|
||||
|
||||
type RosenpassConfig struct {
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
PubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
Addr string
|
||||
|
||||
PermissiveMode bool
|
||||
}
|
||||
|
||||
// ConnConfig is a peer Connection configuration
|
||||
type ConnConfig struct {
|
||||
// Key is a public key of a remote peer
|
||||
@@ -73,10 +82,7 @@ type ConnConfig struct {
|
||||
|
||||
LocalWgPort int
|
||||
|
||||
// RosenpassPubKey is this peer's Rosenpass public key
|
||||
RosenpassPubKey []byte
|
||||
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
|
||||
RosenpassAddr string
|
||||
RosenpassConfig RosenpassConfig
|
||||
|
||||
// ICEConfig ICE protocol configuration
|
||||
ICEConfig icemaker.Config
|
||||
@@ -103,11 +109,14 @@ type Conn struct {
|
||||
|
||||
workerICE *WorkerICE
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
connIDRelay nbnet.ConnectionID
|
||||
connIDICE nbnet.ConnectionID
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||
rosenpassRemoteKey []byte
|
||||
|
||||
wgProxyICE wgproxy.Proxy
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
@@ -211,6 +220,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||
// Close closes this peer Conn issuing a close event to the Conn closeCh
|
||||
func (conn *Conn) Close() {
|
||||
conn.mu.Lock()
|
||||
defer conn.wgWatcherWg.Wait()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
conn.log.Infof("close peer connection")
|
||||
@@ -252,6 +262,7 @@ func (conn *Conn) Close() {
|
||||
}
|
||||
|
||||
conn.setStatusToDisconnected()
|
||||
conn.log.Infof("peer connection has been closed")
|
||||
}
|
||||
|
||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||
@@ -362,6 +373,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
@@ -371,7 +383,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
if err = conn.configureWGEndpoint(ep); err != nil {
|
||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
@@ -404,10 +416,15 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.dumpState.SwitchToRelay()
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||
conn.log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
} else {
|
||||
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
|
||||
@@ -469,16 +486,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = connPriorityRelay
|
||||
conn.statusRelay.Set(StatusConnected)
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
@@ -502,6 +525,7 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
conn.currentConnPriority = connPriorityNone
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -541,13 +565,14 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
|
||||
presharedKey := conn.presharedKey(remoteRPKey)
|
||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
||||
conn.config.WgConfig.RemoteKey,
|
||||
conn.config.WgConfig.AllowedIps,
|
||||
defaultWgKeepAlive,
|
||||
addr,
|
||||
conn.config.WgConfig.PreSharedKey,
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -768,6 +793,44 @@ func (conn *Conn) AllowedIP() netip.Addr {
|
||||
return conn.config.WgConfig.AllowedIps[0].Addr()
|
||||
}
|
||||
|
||||
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
if conn.config.RosenpassConfig.PubKey == nil {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
if remoteRosenpassKey == nil && conn.config.RosenpassConfig.PermissiveMode {
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
if err != nil {
|
||||
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return conn.config.WgConfig.PreSharedKey
|
||||
}
|
||||
|
||||
return determKey
|
||||
}
|
||||
|
||||
// todo: move this logic into Rosenpass package
|
||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
||||
lk := []byte(conn.config.LocalKey)
|
||||
rk := []byte(conn.config.Key) // remote key
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -161,3 +162,145 @@ func TestConn_Status(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_presharedKey(t *testing.T) {
|
||||
conn1 := Conn{
|
||||
config: ConnConfig{
|
||||
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
RosenpassConfig: RosenpassConfig{},
|
||||
},
|
||||
}
|
||||
conn2 := Conn{
|
||||
config: ConnConfig{
|
||||
Key: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
LocalKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
|
||||
RosenpassConfig: RosenpassConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
conn1Permissive bool
|
||||
conn1RosenpassEnabled bool
|
||||
conn2Permissive bool
|
||||
conn2RosenpassEnabled bool
|
||||
conn1ExpectedInitialKey bool
|
||||
conn2ExpectedInitialKey bool
|
||||
}{
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: false,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: false,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: true,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: false,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: true,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: false,
|
||||
},
|
||||
{
|
||||
conn1Permissive: true,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: true,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: false,
|
||||
conn2Permissive: false,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: false,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
{
|
||||
conn1Permissive: false,
|
||||
conn1RosenpassEnabled: true,
|
||||
conn2Permissive: true,
|
||||
conn2RosenpassEnabled: true,
|
||||
conn1ExpectedInitialKey: true,
|
||||
conn2ExpectedInitialKey: true,
|
||||
},
|
||||
}
|
||||
|
||||
conn1.config.RosenpassConfig.PermissiveMode = true
|
||||
for i, test := range tests {
|
||||
tcase := i + 1
|
||||
t.Run(fmt.Sprintf("Rosenpass test case %d", tcase), func(t *testing.T) {
|
||||
conn1.config.RosenpassConfig = RosenpassConfig{}
|
||||
conn2.config.RosenpassConfig = RosenpassConfig{}
|
||||
|
||||
if test.conn1RosenpassEnabled {
|
||||
conn1.config.RosenpassConfig.PubKey = []byte("dummykey")
|
||||
}
|
||||
conn1.config.RosenpassConfig.PermissiveMode = test.conn1Permissive
|
||||
|
||||
if test.conn2RosenpassEnabled {
|
||||
conn2.config.RosenpassConfig.PubKey = []byte("dummykey")
|
||||
}
|
||||
conn2.config.RosenpassConfig.PermissiveMode = test.conn2Permissive
|
||||
|
||||
conn1PresharedKey := conn1.presharedKey(conn2.config.RosenpassConfig.PubKey)
|
||||
conn2PresharedKey := conn2.presharedKey(conn1.config.RosenpassConfig.PubKey)
|
||||
|
||||
if test.conn1ExpectedInitialKey {
|
||||
if conn1PresharedKey == nil {
|
||||
t.Errorf("Case %d: Expected conn1 to have a non-nil key, but got nil", tcase)
|
||||
}
|
||||
} else {
|
||||
if conn1PresharedKey != nil {
|
||||
t.Errorf("Case %d: Expected conn1 to have a nil key, but got %v", tcase, conn1PresharedKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Assert conn2's key expectation
|
||||
if test.conn2ExpectedInitialKey {
|
||||
if conn2PresharedKey == nil {
|
||||
t.Errorf("Case %d: Expected conn2 to have a non-nil key, but got nil", tcase)
|
||||
}
|
||||
} else {
|
||||
if conn2PresharedKey != nil {
|
||||
t.Errorf("Case %d: Expected conn2 to have a nil key, but got %v", tcase, conn2PresharedKey)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,8 +154,8 @@ func (h *Handshaker) sendOffer() error {
|
||||
IceCredentials: IceCredentials{iceUFrag, icePwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassPubKey,
|
||||
RosenpassAddr: h.config.RosenpassAddr,
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
}
|
||||
|
||||
addr, err := h.relay.RelayInstanceAddress()
|
||||
@@ -174,8 +174,8 @@ func (h *Handshaker) sendAnswer() error {
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassPubKey,
|
||||
RosenpassAddr: h.config.RosenpassAddr,
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
}
|
||||
addr, err := h.relay.RelayInstanceAddress()
|
||||
if err == nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/ice/v3"
|
||||
"github.com/pion/logging"
|
||||
"github.com/pion/randutil"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -35,6 +36,10 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
|
||||
//fac.Writer = log.StandardLogger().Writer()
|
||||
|
||||
agentConfig := &ice.AgentConfig{
|
||||
MulticastDNSMode: ice.MulticastDNSModeDisabled,
|
||||
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
|
||||
@@ -51,6 +56,7 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
|
||||
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
|
||||
LocalUfrag: ufrag,
|
||||
LocalPwd: pwd,
|
||||
LoggerFactory: fac,
|
||||
}
|
||||
|
||||
if config.DisableIPv6Discovery {
|
||||
|
||||
@@ -2,40 +2,94 @@ package peer
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// routeEntry holds the route prefix and the corresponding resource ID.
|
||||
type routeEntry struct {
|
||||
prefix netip.Prefix
|
||||
resourceID route.ResID
|
||||
}
|
||||
|
||||
type routeIDLookup struct {
|
||||
localMap sync.Map
|
||||
remoteMap sync.Map
|
||||
localRoutes []routeEntry
|
||||
localLock sync.RWMutex
|
||||
|
||||
remoteRoutes []routeEntry
|
||||
remoteLock sync.RWMutex
|
||||
|
||||
resolvedIPs sync.Map
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
|
||||
_, exists := r.localMap.LoadOrStore(route, resourceID)
|
||||
if exists {
|
||||
log.Tracef("resourceID %s already exists in local map", resourceID)
|
||||
func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) {
|
||||
r.localLock.Lock()
|
||||
defer r.localLock.Unlock()
|
||||
|
||||
// update the resource id if the route already exists.
|
||||
for i, entry := range r.localRoutes {
|
||||
if entry.prefix == route {
|
||||
r.localRoutes[i].resourceID = resourceID
|
||||
log.Tracef("resourceID for route %v updated to %s in local routes", route, resourceID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// append and sort descending by prefix bits (more specific first)
|
||||
r.localRoutes = append(r.localRoutes, routeEntry{prefix: route, resourceID: resourceID})
|
||||
sort.Slice(r.localRoutes, func(i, j int) bool {
|
||||
return r.localRoutes[i].prefix.Bits() > r.localRoutes[j].prefix.Bits()
|
||||
})
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
|
||||
r.localMap.Delete(route)
|
||||
}
|
||||
r.localLock.Lock()
|
||||
defer r.localLock.Unlock()
|
||||
|
||||
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
|
||||
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
|
||||
if exists {
|
||||
log.Tracef("resourceID %s already exists in remote map", resourceID)
|
||||
for i, entry := range r.localRoutes {
|
||||
if entry.prefix == route {
|
||||
r.localRoutes = append(r.localRoutes[:i], r.localRoutes[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
|
||||
r.remoteMap.Delete(route)
|
||||
func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) {
|
||||
r.remoteLock.Lock()
|
||||
defer r.remoteLock.Unlock()
|
||||
|
||||
for i, entry := range r.remoteRoutes {
|
||||
if entry.prefix == route {
|
||||
r.remoteRoutes[i].resourceID = resourceID
|
||||
log.Tracef("resourceID for route %v updated to %s in remote routes", route, resourceID)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// append and sort descending by prefix bits.
|
||||
r.remoteRoutes = append(r.remoteRoutes, routeEntry{prefix: route, resourceID: resourceID})
|
||||
sort.Slice(r.remoteRoutes, func(i, j int) bool {
|
||||
return r.remoteRoutes[i].prefix.Bits() > r.remoteRoutes[j].prefix.Bits()
|
||||
})
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
|
||||
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
|
||||
r.remoteLock.Lock()
|
||||
defer r.remoteLock.Unlock()
|
||||
|
||||
for i, entry := range r.remoteRoutes {
|
||||
if entry.prefix == route {
|
||||
r.remoteRoutes = append(r.remoteRoutes[:i], r.remoteRoutes[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) {
|
||||
r.resolvedIPs.Store(route.Addr(), resourceID)
|
||||
}
|
||||
|
||||
@@ -44,37 +98,35 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
|
||||
}
|
||||
|
||||
// Lookup returns the resource ID for the given IP address
|
||||
// and a bool indicating if the IP is an exit node
|
||||
func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) {
|
||||
var isExitNode bool
|
||||
|
||||
resId, ok := r.resolvedIPs.Load(ip)
|
||||
if ok {
|
||||
return resId.(string), false
|
||||
// and a bool indicating if the IP is an exit node.
|
||||
func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) {
|
||||
if res, ok := r.resolvedIPs.Load(ip); ok {
|
||||
return res.(route.ResID), false
|
||||
}
|
||||
|
||||
var resourceID string
|
||||
r.localMap.Range(func(key, value interface{}) bool {
|
||||
pref := key.(netip.Prefix)
|
||||
if pref.Contains(ip) {
|
||||
resourceID = value.(string)
|
||||
isExitNode = pref.Bits() == 0
|
||||
return false
|
||||
var resourceID route.ResID
|
||||
var isExitNode bool
|
||||
|
||||
r.localLock.RLock()
|
||||
for _, entry := range r.localRoutes {
|
||||
if entry.prefix.Contains(ip) {
|
||||
resourceID = entry.resourceID
|
||||
isExitNode = entry.prefix.Bits() == 0
|
||||
break
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
r.localLock.RUnlock()
|
||||
|
||||
if resourceID == "" {
|
||||
r.remoteMap.Range(func(key, value interface{}) bool {
|
||||
pref := key.(netip.Prefix)
|
||||
if pref.Contains(ip) {
|
||||
resourceID = value.(string)
|
||||
isExitNode = pref.Bits() == 0
|
||||
return false
|
||||
r.remoteLock.RLock()
|
||||
for _, entry := range r.remoteRoutes {
|
||||
if entry.prefix.Contains(ip) {
|
||||
resourceID = entry.resourceID
|
||||
isExitNode = entry.prefix.Bits() == 0
|
||||
break
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
r.remoteLock.RUnlock()
|
||||
}
|
||||
|
||||
return resourceID, isExitNode
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const eventQueueSize = 10
|
||||
@@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
|
||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
@@ -581,14 +582,13 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||
}
|
||||
|
||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
if err == nil {
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
if d.localPeer.Routes == nil {
|
||||
@@ -596,8 +596,6 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
|
||||
}
|
||||
|
||||
d.localPeer.Routes[route] = struct{}{}
|
||||
|
||||
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
|
||||
}
|
||||
|
||||
// RemoveLocalPeerStateRoute removes a route from the local peer state
|
||||
@@ -606,14 +604,30 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse prefix %s: %v", route, err)
|
||||
return
|
||||
if err == nil {
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
}
|
||||
|
||||
delete(d.localPeer.Routes, route)
|
||||
}
|
||||
|
||||
d.routeIDLookup.RemoveLocalRouteID(pref)
|
||||
// AddResolvedIPLookupEntry adds a resolved IP lookup entry
|
||||
func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
|
||||
}
|
||||
|
||||
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
|
||||
func (d *Status) RemoveResolvedIPLookupEntry(route string) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
pref, err := netip.ParsePrefix(route)
|
||||
if err == nil {
|
||||
d.routeIDLookup.RemoveResolvedIP(pref)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
|
||||
@@ -707,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
|
||||
d.nsGroupStates = dnsStates
|
||||
}
|
||||
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
|
||||
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ type WGWatcher struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
waitGroup sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
w.ctxLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||
w.ctx = ctx
|
||||
w.ctxCancel = ctxCancel
|
||||
w.ctxLock.Unlock()
|
||||
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
w.waitGroup.Add(1)
|
||||
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
}
|
||||
|
||||
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||
@@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() {
|
||||
|
||||
w.ctxCancel()
|
||||
w.ctxCancel = nil
|
||||
w.waitGroup.Wait()
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
defer w.waitGroup.Done()
|
||||
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user