Compare commits

...

79 Commits

Author SHA1 Message Date
bcmmbaga
24053750d9 add debug logs for user state change
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-04-30 15:21:01 +03:00
Zoltan Papp
d5081cef90 [client] Revert mgm client error handling (#3764) 2025-04-30 13:09:00 +02:00
Bethuel Mmbaga
488e619ec7 [management] Add network traffic events pagination (#3580)
* Add network traffic events pagination schema
2025-04-30 11:51:40 +03:00
hakansa
d2b42c8f68 [client] Add macOS .pkg installer support to installation script (#3755)
[client] Add macOS .pkg installer support to installation script
2025-04-29 13:43:42 +03:00
Maycon Santos
2f44fe2e23 [client] Feature/upload bundle (#3734)
Add an upload bundle option with the flag --upload-bundle; by default, the upload will use a NetBird address, which can be replaced using the flag --upload-bundle-url.

The upload server is available under the /upload-server path. The release change will push a docker image to netbirdio/upload image repository.

The server supports using s3 with pre-signed URL for direct upload and local file for storing bundles.
2025-04-29 00:43:50 +02:00
Bethuel Mmbaga
d8dc107bee [management] Skip IdP cache warm-up on Redis if data exists (#3733)
* Add Redis cache check to skip warm-up on startup if cache is already populated
* Refactor Redis test container setup for reusability
2025-04-28 15:10:40 +03:00
Viktor Liu
3fa915e271 [misc] Exclude client benchmarks from CI (#3752) 2025-04-28 13:40:36 +02:00
Pedro Maia Costa
47c3afe561 [management] add missing network admin mapping (#3751) 2025-04-28 11:05:27 +01:00
hakansa
84bfecdd37 [client] add byte counters & ruleID for routed traffic on userspace (#3653)
* [client] add byte counters for routed traffic on userspace 
* [client] add allowed ruleID for routed traffic on userspace
2025-04-28 10:10:41 +03:00
Viktor Liu
3cf87b6846 [client] Run container tests more generically (#3737) 2025-04-25 18:50:44 +02:00
Maycon Santos
4fe4c2054d [client] Move static check when running on foreground (#3742) 2025-04-25 18:25:48 +02:00
Pascal Fischer
38ada44a0e [management] allow impersonation via pats (#3739) 2025-04-25 16:40:54 +02:00
Pedro Maia Costa
dbf81a145e [management] network admin role (#3720) 2025-04-25 15:14:32 +01:00
Pedro Maia Costa
39483f8ca8 [management] Auditor role (#3721) 2025-04-25 15:04:25 +01:00
Carlos Hernandez
c0eaea938e [client] Fix macos privacy warning when checking static info (#3496)
avoid checking static info with a init call
2025-04-25 14:41:57 +02:00
Viktor Liu
ef8b8a2891 [client] Ensure dst-type local marks can overwrite nat marks (#3738) 2025-04-25 12:43:20 +02:00
Zoltan Papp
2817f62c13 [client] Fix error handling case of flow grpc error (#3727)
When a gRPC error occurs in the Flow package, it will be propagated to the upper layers and handled similarly to a Management gRPC error.

Always report a disconnected state in the event of any error
Hide the underlying gRPC errors
Force close the gRPC connection in the event of any error
2025-04-25 09:26:18 +02:00
Viktor Liu
4a9049566a [client] Set up firewall rules for dns routes dynamically based on dns response (#3702) 2025-04-24 17:37:28 +02:00
Viktor Liu
85f92f8321 [client] Add more userspace filter ACL test cases (#3730) 2025-04-24 12:57:46 +02:00
Viktor Liu
714beb6e3b [client] Fix exit node deselection (#3722) 2025-04-24 12:36:05 +02:00
Viktor Liu
400b9fca32 [management] Add firewall rule route ID and missing route domains (#3700) 2025-04-23 21:29:46 +02:00
hakansa
4013298e22 [client/ui] add connecting state to status handling (#3712) 2025-04-23 21:04:38 +02:00
Pascal Fischer
312bfd9bd7 [management] support custom domains per account (#3726) 2025-04-23 19:36:53 +02:00
Pascal Fischer
8db05838ca [misc] Change github runner for docker test (#3707) 2025-04-23 19:35:26 +02:00
Misha Bragin
c69df13515 [management] Add account meta (#3724) 2025-04-23 18:44:22 +02:00
Pascal Fischer
986eb8c1e0 [management] fix lastLogin on dashboard (#3725) 2025-04-23 15:54:49 +02:00
dependabot[bot]
197761ba4d Bump github.com/redis/go-redis/v9 from 9.7.1 to 9.7.3 (#3553)
Bumps [github.com/redis/go-redis/v9](https://github.com/redis/go-redis) from 9.7.1 to 9.7.3.
- [Release notes](https://github.com/redis/go-redis/releases)
- [Changelog](https://github.com/redis/go-redis/blob/master/CHANGELOG.md)
- [Commits](https://github.com/redis/go-redis/compare/v9.7.1...v9.7.3)

---
updated-dependencies:
- dependency-name: github.com/redis/go-redis/v9
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-23 10:21:36 +02:00
dependabot[bot]
f74ea64c7b Bump golang.org/x/net from 0.36.0 to 0.38.0 (#3695)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.36.0 to 0.38.0.
- [Commits](https://github.com/golang/net/compare/v0.36.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.38.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-23 10:20:51 +02:00
Viktor Liu
3b7b9d25bc [client] Keep new routes selected unless all are deselected (#3692) 2025-04-23 01:07:04 +02:00
Pascal Fischer
1a6d6b3109 [management] fix github run id (#3705) 2025-04-18 11:21:54 +02:00
Pascal Fischer
f686615876 [management] benchmarks use ref_name instead (#3704) 2025-04-17 21:57:54 +02:00
Pascal Fischer
a4311f574d [management] push benchmark results to grafana (#3701) 2025-04-17 21:01:23 +02:00
Pierre Timmermans
0bb8eae903 [docs] fix: broken link in the README file (#3697)
improve README.md, broken link for activity logging
2025-04-17 14:48:10 +02:00
Pascal Fischer
e0b33d325d [management] permissions manager use crud operations (#3690) 2025-04-16 17:25:03 +02:00
Zoltan Papp
c38e07d89a [client] Fix Rosenpass permissive mode handling (#3689)
fixes the Rosenpass preshared key handling to enable successful WireGuard handshakes when one side is in permissive mode. Key changes include:

Updating field accesses from RosenpassPubKey/RosenpassAddr to RosenpassConfig.PubKey/RosenpassConfig.Addr.
Modifying the preshared key computation logic to account for permissive mode.
Revising peer configuration in the Engine to use the new RosenpassConfig struct.
2025-04-16 16:04:43 +02:00
Lamera
a37368fff4 [misc] update gpt file permissions in install.sh (#3663)
* Fix install.sh for some installations

Fix install.sh for some installations by explicitly setting the file permissions

* Add sudo
2025-04-16 14:23:25 +02:00
Viktor Liu
0c93bd3d06 [client] Keep selecting new networks after first deselection (#3671) 2025-04-16 13:55:26 +02:00
Viktor Liu
a675531b5c [client] Set up signal to generate debug bundles (#3683) 2025-04-16 11:06:22 +02:00
hakansa
7cb366bc7d [client] Remove logrus writer assignment in pion logging (#3684) 2025-04-15 18:15:52 +03:00
Viktor Liu
a354004564 [client] Add remaining debug profiles (#3681) 2025-04-15 13:06:28 +02:00
Pedro Maia Costa
75bdd47dfb [management] get current user endpoint (#3666) 2025-04-15 11:06:07 +01:00
Viktor Liu
b165f63327 [client] Add heap profile to debug bundle (#3679) 2025-04-15 11:36:41 +02:00
hakansa
51bb52cdf5 [client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
2025-04-15 10:54:17 +03:00
Pedro Maia Costa
4134b857b4 [management] add permissions manager to geolocation handler (#3665) 2025-04-14 17:57:58 +01:00
Vlad
7839d2c169 [management] Refactor/management/updchannel (#3645)
* refactoring updatechannel - use read mutex for send update
2025-04-11 18:22:59 +03:00
Pascal Fischer
b9f82e2f8a [management] Buffer updateAccountPeers calls (#3644) 2025-04-11 17:21:05 +02:00
Pedro Maia Costa
fd2a21c65d [management] remove unnecessary access control middleware (#3650) 2025-04-11 10:43:59 +01:00
Maycon Santos
82d982b0ab [management,client] Add support to configurable prompt login (#3660) 2025-04-11 11:34:55 +02:00
Maycon Santos
9e24fe7701 [docs] Fix a few typos on table (#3658) 2025-04-10 17:57:39 +02:00
Pedro Maia Costa
e470701b80 [ci] include stash in pr template (#3657) 2025-04-10 16:30:44 +01:00
Viktor Liu
e3ce026355 [client] Fix race dns cleanup race condition (#3652) 2025-04-10 13:21:14 +02:00
Pascal Fischer
5ea2806663 [management] use permission modules (#3622) 2025-04-10 11:06:52 +02:00
Viktor Liu
d6b0673580 [client] Support CNAME in local resolver (#3646) 2025-04-10 10:38:47 +02:00
Pedro Maia Costa
14913cfa7a add git town config (#3555) 2025-04-09 20:18:52 +01:00
Viktor Liu
03f600b576 [client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632) 2025-04-08 13:41:13 +02:00
Viktor Liu
192c97aa63 [client] Support IP fragmentation in userspace (#3639) 2025-04-08 12:49:14 +02:00
Maycon Santos
4db78db49a [misc] Update FreeBSD workflow (#3638)
Update FreeBSD release to 14.2 and download Go package directly since port wasn't finding the package to install
2025-04-08 09:15:09 +02:00
Viktor Liu
87e600a4f3 [client] Automatically register match domains for DNS routes (#3614) 2025-04-07 15:18:45 +02:00
Viktor Liu
6162aeb82d [client] Mark netbird data plane traffic to identify interface traffic correctly (#3623) 2025-04-07 13:14:56 +02:00
hakansa
1ba1e092ce [client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
2025-04-07 15:16:12 +08:00
hakansa
86dbb4ee4f [client] Add no-browser flag to login and up commands for SSO login control (#3610)
* [client] Add no-browser flag to login and up commands for SSO login control (#3610)
2025-04-07 14:39:53 +08:00
hakansa
4af177215f [client] Fix Status Recorder Route Removal Logic to Handle Dynamic Routes Correctly 2025-04-06 09:57:28 +08:00
Viktor Liu
df9c1b9883 [client] Improve TCP conn tracking (#3572) 2025-04-05 11:42:15 +02:00
Viktor Liu
5752bb78f2 [client] Fix missing inbound flows in Linux userspace mode with native router (#3624)
* Fix missing inbound flows in Linux userspace mode with native router

* Fix route enable/disable order for userspace mode
2025-04-05 11:41:31 +02:00
Maycon Santos
fbd783ad58 [client] Use the netbird logger for ice and grpc (#3603)
updates the logging implementation to use the netbird logger for both ICE and gRPC components. The key changes include:

- Introducing a gRPC logger configuration in util/log.go that integrates with the netbird logging setup.
- Updating the log hook in formatter/hook/hook.go to ensure a default caller is used when not set.
- Refactoring ICE agent and UDP multiplexers to use a unified logger via the new getLogger() method.
2025-04-04 18:30:47 +02:00
Viktor Liu
80702b9323 [client] Fix dns forwarder handling of requested record types (#3615) 2025-04-03 13:58:36 +02:00
Viktor Liu
09243a0fe0 [management] Remove remaining backend linux router limitation (#3589) 2025-04-01 21:29:57 +02:00
Maycon Santos
3658215747 [client] Force new user login on PKCE auth in CLI (#3604)
With this change, browser session won't be considered for cli authentication and credentials will be requested
2025-04-01 10:29:29 +02:00
Viktor Liu
48ffec95dd Improve local ip lookup (#3551)
- lower memory footprint in most cases
- increase accuracy
2025-03-31 10:05:57 +02:00
Pedro Maia Costa
cbec7bda80 [management] permission manager validate account access (#3444) 2025-03-30 17:08:22 +02:00
Zoltan Papp
21464ac770 [client] Fix close WireGuard watcher (#3598)
This PR fixes issues with closing the WireGuard watcher by adjusting its asynchronous invocation and synchronization.

Update tests in wg_watcher_test.go to launch the watcher in a goroutine and add a delay for timing.
Modify wg_watcher.go to run the periodic handshake check synchronously by removing the waitGroup and goroutine.
Enhance conn.go to wait on the watcher wait group during connection close and add a note for potential further synchronization
2025-03-28 20:12:31 +01:00
Zoltan Papp
ed5647028a [client] Prevent calling the onDisconnected callback in incorrect state (#3582)
Prevent calling the onDisconnected callback if the ICE connection has never been established

If call onDisconnected without onConnected then overwrite the relayed status in the conn priority variable.
2025-03-28 18:08:26 +01:00
Viktor Liu
29a6e5be71 [client] Stop flow grpc receiver properly (#3596) 2025-03-28 16:08:31 +01:00
Viktor Liu
6124e3b937 [client] Disable systemd-resolved default route explicitly on match domains only (#3584) 2025-03-28 11:14:32 +01:00
Maycon Santos
50f5cc48cd [management] Fix extended config when nil (#3593)
* Fix extended config when nil

* update integrations
2025-03-27 23:07:10 +01:00
Viktor Liu
101cce27f2 [client] Ensure status recorder is always initialized (#3588)
* Ensure status recorder is always initialized

* Add test

* Add subscribe test
2025-03-27 22:48:11 +01:00
Maycon Santos
a4f04f5570 [management] fix extend call and move config to types (#3575)
This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include:
- Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath).
- Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety.
- Adjusting various test and migration code paths to correctly reference the new configuration and engine types.
2025-03-27 13:04:50 +01:00
hakansa
fceb3ca392 [client] fix route handling for local peer state (#3586) 2025-03-27 19:31:04 +08:00
Bethuel Mmbaga
34d86c5ab8 [management] Sync account peers on network router group changes (#3573)
- Updates account peers when a group linked to a network router is modified
- Prevents group deletion if it's still being used by any network router
2025-03-27 12:19:22 +01:00
260 changed files with 12212 additions and 4838 deletions

27
.git-branches.toml Normal file
View File

@@ -0,0 +1,27 @@
# More info around this file at https://www.git-town.com/configuration-file
[branches]
main = "main"
perennials = []
perennial-regex = ""
[create]
new-branch-type = "feature"
push-new-branches = false
[hosting]
dev-remote = "origin"
# platform = ""
# origin-hostname = ""
[ship]
delete-tracking-branch = false
strategy = "squash-merge"
[sync]
feature-strategy = "merge"
perennial-strategy = "rebase"
prototype-strategy = "merge"
push-hook = true
tags = true
upstream = false

View File

@@ -2,6 +2,10 @@
## Issue ticket number and link
## Stack
<!-- branch-stack -->
### Checklist
- [ ] Is it a bug fix
- [ ] Is a typo/documentation fix

21
.github/workflows/git-town.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: Git Town
on:
pull_request:
branches:
- '**'
jobs:
git-town:
name: Display the branch stack
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
with:
skip-single-stacks: true

View File

@@ -22,14 +22,20 @@ jobs:
with:
usesh: true
copyback: false
release: "14.1"
release: "14.2"
prepare: |
pkg install -y go pkgconf xorg
pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...

View File

@@ -146,6 +146,65 @@ jobs:
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
${{ steps.go-env.outputs.cache_dir }}
${{ steps.go-env.outputs.modcache_dir }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Run tests in container
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
docker run --rm \
--cap-add=NET_ADMIN \
--privileged \
-v $PWD:/app \
-w /app \
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
-e CGO_ENABLED=1 \
-e CI=true \
-e DOCKER_CI=true \
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
name: "Relay / Unit"
needs: [build-cache]
@@ -179,13 +238,6 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -232,13 +284,6 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -286,13 +331,6 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -314,6 +352,7 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
@@ -353,13 +392,6 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -380,10 +412,11 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
-timeout 20m ./management/...
api_benchmark:
name: "Management / Benchmark (API)"
@@ -396,6 +429,33 @@ jobs:
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -420,13 +480,6 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -447,11 +500,13 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/...
api_integration_test:
@@ -489,13 +544,6 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -505,89 +553,8 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin
- name: Run Shared Sock tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@@ -178,6 +178,7 @@ jobs:
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
- name: Install modules
run: go mod tidy

View File

@@ -96,6 +96,20 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
binary: netbird-upload
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -409,6 +423,52 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -475,7 +535,17 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -57,16 +57,16 @@
### Key features
| Connectivity | Management | Security | Automation | Platforms |
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> |
| Connectivity | Management | Security | Automation| Platforms |
|----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
||||| <ul><li>- \[x] Docker</ui></li> |
### Quickstart with NetBird Cloud

View File

@@ -1,5 +1,6 @@
FROM alpine:3.21.3
RUN apk add --no-cache ca-certificates iptables ip6tables
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird
COPY netbird /usr/local/bin/netbird

View File

@@ -11,9 +11,12 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -84,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
})
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())
cmd.Println(resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil
}
@@ -208,12 +222,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
})
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
@@ -239,7 +256,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println(resp.GetPath())
cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil
}
@@ -326,3 +351,34 @@ func formatDuration(d time.Duration) string {
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
if connectClient != nil {
networkMap, err = connectClient.GetLatestNetworkMap()
if err != nil {
log.Warnf("Failed to get latest network map: %v", err)
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: config,
StatusRecorder: recorder,
NetworkMap: networkMap,
LogFile: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
log.Errorf("Failed to generate debug bundle: %v", err)
return
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}

39
client/cmd/debug_unix.go Normal file
View File

@@ -0,0 +1,39 @@
//go:build unix
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
usr1Ch := make(chan os.Signal, 1)
signal.Notify(usr1Ch, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-usr1Ch:
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
go generateDebugBundle(config, recorder, connectClient, logFilePath)
}
}
}()
}

126
client/cmd/debug_windows.go Normal file
View File

@@ -0,0 +1,126 @@
package cmd
import (
"context"
"errors"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
waitTimeout = 5 * time.Second
)
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
// Example usage with PowerShell:
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
// $evt.Set()
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
env := os.Getenv(envListenEvent)
if env == "" {
return
}
listenEvent, err := strconv.ParseBool(env)
if err != nil {
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
return
}
if !listenEvent {
return
}
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
return
}
// TODO: restrict access by ACL
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
} else {
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
return
}
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
}
func waitForEvent(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
eventHandle windows.Handle,
) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on debug event. Triggering debug bundle generation.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
}
go generateDebugBundle(config, recorder, connectClient, logFilePath)
case uint32(windows.WAIT_TIMEOUT):
default:
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
case <-ctx.Done():
return
}
}
}
}

View File

@@ -19,6 +19,10 @@ import (
"github.com/netbirdio/netbird/util"
)
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
var loginCmd = &cobra.Command{
Use: "login",
Short: "login to the Netbird Management Service (first run)",
@@ -51,6 +55,9 @@ var loginCmd = &cobra.Command{
return err
}
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
@@ -127,7 +134,7 @@ var loginCmd = &cobra.Command{
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
@@ -198,7 +205,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
@@ -212,19 +219,27 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil
}
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
if noBrowser {
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
} else {
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
}
cmd.Println("")
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
@@ -39,6 +40,9 @@ const (
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
defaultBundleURL = "https://upload.debug.netbird.io" + types.GetURLPath
)
var (
@@ -75,6 +79,8 @@ var (
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
rootCmd = &cobra.Command{
Use: "netbird",
@@ -181,6 +187,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, defaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -16,12 +16,17 @@ import (
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
)
func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {

View File

@@ -12,9 +12,11 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
@@ -32,7 +34,7 @@ import (
func startTestingServices(t *testing.T) string {
t.Helper()
config := &mgmt.Config{}
config := &types.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
t.Fatal(err)
@@ -67,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
@@ -90,13 +92,18 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil {
t.Fatal(err)
}

View File

@@ -32,12 +32,16 @@ const (
const (
dnsLabelsFlag = "extra-dns-labels"
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
)
var (
foregroundMode bool
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
upCmd = &cobra.Command{
Use: "up",
@@ -65,6 +69,9 @@ func init() {
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -212,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
}
@@ -349,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {

View File

@@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -243,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -38,10 +38,12 @@ const (
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
@@ -55,18 +57,18 @@ type ruleInfo struct {
}
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
Source firewall.Network
Destination firewall.Network
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
}
type routeRules map[string][]string
// the ipset library currently does not support comments, so we use the name only (string)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct {
@@ -115,6 +117,10 @@ func (r *router) init(stateManager *statemanager.Manager) error {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.updateState()
return nil
@@ -123,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -134,27 +140,28 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil
}
var setName string
var source firewall.Network
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
source.Set = firewall.NewPrefixSet(sources)
} else if len(sources) > 0 {
source.Prefix = sources[0]
}
params := routeFilteringRuleParams{
Sources: sources,
Source: source,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule := genRouteFilteringRuleSpec(params)
rule, err := r.genRouteRuleSpec(params, sources)
if err != nil {
return nil, fmt.Errorf("generate route rule spec: %w", err)
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
@@ -177,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("route rule %s not found", ruleKey)
@@ -198,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil
}
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
func (r *router) decrementSetCounter(rule []string) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule []string) []string {
var sets []string
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
}
}
return sets
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
@@ -225,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
@@ -264,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -307,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
}
return nil
@@ -348,12 +370,16 @@ func (r *router) Reset() error {
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
r.updateState()
return nberrors.FormatErrorOrNil(merr)
@@ -423,6 +449,57 @@ func (r *router) createContainers() error {
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
@@ -464,7 +541,7 @@ func (r *router) insertEstablishedRule(chain string) error {
}
func (r *router) addJumpRules() error {
// Jump to NAT chain
// Jump to nat chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err)
@@ -538,12 +615,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
)
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}
@@ -561,6 +652,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("marking rule %s not found", ruleKey)
}
@@ -726,17 +821,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
var rule []string
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
if err != nil {
return nil, fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", params.Destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, "-d", params.Destination.String())
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
@@ -746,7 +845,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
rule = append(rule, "-j", actionToStr(params.Action))
return rule
return rule, nil
}
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func applyPort(flag string, port *firewall.Port) []string {

View File

@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
// 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -58,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
pair := firewall.RouterPair{
ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
Masquerade: true,
}
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
@@ -345,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Source: source,
Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
setName := firewall.NewPrefixSet(tt.sources).HashedName()
expectedRule, err = r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
@@ -376,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
})
}
}
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
expected []string
}{
{
name: "Basic rule with two sets",
rule: []string{
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
},
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
},
{
name: "No sets",
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
expected: []string{},
},
{
name: "Multiple sets with different positions",
rule: []string{
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
},
expected: []string{"set1", "set-abc123"},
},
{
name: "Boundary case - sequence appears at end",
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
expected: []string{"final-set"},
},
{
name: "Incomplete pattern - missing set name",
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
expected: []string{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
return
}
for i, set := range result {
if set != tc.expected[i] {
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
}
}
})
}
}

View File

@@ -1,13 +1,10 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
@@ -43,6 +40,18 @@ const (
// Action is the action to be taken on a rule
type Action int
// String returns the string representation of the action
func (a Action) String() string {
switch a {
case ActionAccept:
return "accept"
case ActionDrop:
return "drop"
default:
return "unknown"
}
}
const (
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
@@ -50,6 +59,33 @@ const (
ActionDrop
)
// Network is a rule destination, either a set or a prefix
type Network struct {
Set Set
Prefix netip.Prefix
}
// String returns the string representation of the destination
func (d Network) String() string {
if d.Prefix.IsValid() {
return d.Prefix.String()
}
if d.IsSet() {
return d.Set.HashedName()
}
return "<invalid network>"
}
// IsSet returns true if the destination is a set
func (d Network) IsSet() bool {
return d.Set != Set{}
}
// IsPrefix returns true if the destination is a valid prefix
func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
@@ -83,10 +119,9 @@ type Manager interface {
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination Network,
proto Protocol,
sPort *Port,
dPort *Port,
sPort, dPort *Port,
action Action,
) (Rule, error)
@@ -119,6 +154,9 @@ type Manager interface {
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
}
func GenKey(format string, pair RouterPair) string {
@@ -153,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {

View File

@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
result := manager.NewPrefixSet(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
result1 := manager.NewPrefixSet([]netip.Prefix{})
result2 := manager.NewPrefixSet([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)

View File

@@ -1,15 +1,13 @@
package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct {
ID route.ID
Source netip.Prefix
Destination netip.Prefix
Source Network
Destination Network
Masquerade bool
Inverse bool
}

View File

@@ -0,0 +1,74 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
type Set struct {
hash [4]byte
comment string
}
// String returns the string representation of the set: hashed name and comment
func (h Set) String() string {
if h.comment == "" {
return h.HashedName()
}
return h.HashedName() + ": " + h.comment
}
// HashedName returns the string representation of the hash
func (h Set) HashedName() string {
return fmt.Sprintf(
"nb-%s",
hex.EncodeToString(h.hash[:]),
)
}
// Comment returns the comment of the set
func (h Set) Comment() string {
return h.comment
}
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()
for _, src := range prefixes {
bytes, err := src.MarshalBinary()
if err != nil {
log.Warnf("failed to marshal prefix %s: %v", src, err)
}
hash.Write(bytes)
}
var set Set
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}
// NewDomainSet generates a unique name for an ipset based on the given domains.
func NewDomainSet(domains domain.List) Set {
slices.Sort(domains)
hash := sha256.New()
for _, d := range domains {
hash.Write([]byte(d.PunycodeString()))
}
set := Set{
comment: domains.SafeString(),
}
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}

View File

@@ -25,9 +25,10 @@ const (
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
@@ -462,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
// Chain is created by route manager
// TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
}
m.addFwmarkToForward(chainFwFilter)

View File

@@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -242,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state
// Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -359,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
@@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
Masquerade: true,
}
err = manager.AddNatRule(pair)

View File

@@ -10,7 +10,6 @@ import (
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@@ -44,9 +43,14 @@ const (
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
type router struct {
conn *nftables.Conn
workTable *nftables.Table
@@ -54,7 +58,7 @@ type router struct {
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
@@ -100,6 +104,10 @@ func (r *router) init(workTable *nftables.Table) error {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
return nil
}
@@ -159,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
return nil, fmt.Errorf("unable to list tables: %v", err)
}
for _, table := range tables {
@@ -196,15 +204,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePostrouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
}
})
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
@@ -220,7 +234,83 @@ func (r *router) createContainers() error {
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
return fmt.Errorf("initialize tables: %v", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return nil
@@ -230,7 +320,7 @@ func (r *router) createContainers() error {
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -245,23 +335,29 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
var source firewall.Network
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
source.Prefix = sources[0]
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
// If there are multiple sources, use a set
source.Set = firewall.NewPrefixSet(sources)
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
sourceExp, err := r.applyNetwork(source, sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle protocol
if proto != firewall.ProtocolALL {
@@ -305,39 +401,27 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
return nil, fmt.Errorf("create or get ipset: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
return getIpSetExprs(ref, isSource)
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -356,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
prefixes := firewall.MergeIPRanges(input.prefixes)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range sources {
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
@@ -407,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
return elements
}
// calculateLastIP determines the last IP in a given prefix.
@@ -442,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
@@ -452,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return ""
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
}
}
return sets
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@@ -500,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
}
return nil
@@ -508,8 +608,15 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq
if pair.Inverse {
@@ -517,26 +624,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
@@ -547,6 +634,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Data: ifname(r.wgIface.Name()),
},
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
@@ -576,9 +666,11 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNamePrerouting],
Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs,
UserData: []byte(ruleKey),
})
@@ -659,8 +751,15 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{
&expr.Counter{},
@@ -669,7 +768,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@@ -682,7 +782,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
@@ -697,11 +797,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
}
return nil
@@ -912,12 +1014,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -925,10 +1029,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
@@ -936,16 +1040,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
log.Debugf("prerouting rule %s not found", ruleKey)
}
return nil
@@ -957,7 +1064,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
return fmt.Errorf(" unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
@@ -1231,13 +1338,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
if err != nil {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
ones := prefix.Bits()
@@ -1324,3 +1472,48 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs
}
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
if chain.Name == chainNameManglePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()

View File

@@ -15,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: false,
},
},
@@ -24,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
},
},
@@ -40,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
},
},

View File

@@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state
// Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -22,7 +21,7 @@ const (
firewallRuleName = "Netbird"
)
// Reset firewall to the default state
// Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {

View File

@@ -1,7 +1,6 @@
package conntrack
import (
"context"
"net/netip"
"testing"
@@ -12,7 +11,7 @@ import (
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {

View File

@@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}

View File

@@ -23,11 +23,11 @@ const (
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPSyn uint8 = 0x02
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPAck uint8 = 0x10
TCPUrg uint8 = 0x20
)
@@ -41,7 +41,7 @@ const (
)
// TCPState represents the state of a TCP connection
type TCPState int
type TCPState int32
func (s TCPState) String() string {
switch s {
@@ -89,22 +89,25 @@ const (
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
SourcePort uint16
DestPort uint16
state atomic.Int32
tombstone atomic.Bool
}
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
// GetState safely retrieves the current state
func (t *TCPConnTrack) GetState() TCPState {
return TCPState(t.state.Load())
}
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
// SetState safely updates the current state
func (t *TCPConnTrack) SetState(state TCPState) {
t.state.Store(int32(state))
}
// CompareAndSwapState atomically changes the state from old to new if current == old
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
return t.state.CompareAndSwap(int32(old), int32(newState))
}
// IsTombstone safely checks if the connection is marked for deletion
@@ -125,13 +128,17 @@ type TCPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
waitTimeout := TimeWaitTimeout
if timeout == 0 {
timeout = DefaultTCPTimeout
} else {
waitTimeout = timeout / 45
}
ctx, cancel := context.WithCancel(context.Background())
@@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger,
}
@@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
t.updateState(key, conn, flags, direction, size)
return key, true
}
@@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
@@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
if exists || flags&TCPSyn == 0 {
return
}
@@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
conn.UpdateCounters(direction, size)
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
t.connections[key] = conn
@@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
@@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
if !exists || conn.IsTombstone() {
return false
}
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
return t.handleRst(key, conn, size)
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
}
return false
}
conn.Lock()
t.updateState(key, conn, flags, false)
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
return isEstablished || isValidState
}
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool {
if conn.IsTombstone() {
return true
}
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.updateState(key, conn, flags, nftypes.Ingress, size)
return true
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
state := conn.State
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
currentState := conn.GetState()
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}()
return
}
switch state {
var newState TCPState
switch currentState {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
if packetDir != conn.Direction {
newState = TCPStateEstablished
} else {
// Simultaneous open
conn.State = TCPStateSynReceived
newState = TCPStateSynReceived
}
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
if packetDir == conn.Direction {
newState = TCPStateEstablished
}
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if isOutbound {
conn.State = TCPStateFinWait1
if packetDir == conn.Direction {
newState = TCPStateFinWait1
} else {
conn.State = TCPStateCloseWait
newState = TCPStateCloseWait
}
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
if packetDir != conn.Direction {
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
newState = TCPStateClosing
case flags&TCPFin != 0:
newState = TCPStateClosing
case flags&TCPAck != 0:
newState = TCPStateFinWait2
}
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
newState = TCPStateTimeWait
}
case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
newState = TCPStateTimeWait
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
conn.State = TCPStateLastAck
newState = TCPStateLastAck
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
newState = TCPStateClosed
}
}
// Send close event for gracefully closed connections
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
}
}
@@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
}
return true
}
switch state {
case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent:
// TODO: support simultaneous open
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0
case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0
@@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
@@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() {
}
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.IsEstablished():
currentState := conn.GetState()
switch currentState {
case TCPStateTimeWait:
timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
// Return IPs to pool
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change
if conn.State != TCPStateTimeWait {
if currentState != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -0,0 +1,83 @@
package conntrack
import (
"net/netip"
"testing"
"time"
)
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) {
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{
@@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) {
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
require.Equal(t, TCPStateClosed, conn.GetState())
}
})
}
}
func TestTCPRetransmissions(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test SYN retransmission
t.Run("SYN Retransmission", func(t *testing.T) {
// Initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Retransmit SYN (should not affect the state machine)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Verify we're still in SYN-SENT state
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateSynSent, conn.GetState())
// Complete the handshake
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Verify we're in ESTABLISHED state
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test ACK retransmission in established state
t.Run("ACK Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Retransmit ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test FIN retransmission
t.Run("FIN Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Retransmit FIN (should not change state)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
})
}
func TestTCPDataTransfer(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Data Transfer", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Receive ACK for data
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
require.True(t, valid)
// Receive data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
require.True(t, valid)
// Send ACK for received data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
})
}
func TestTCPHalfClosedConnections(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test half-closed connection: local end closes, remote end continues sending data
t.Run("Local Close, Remote Data", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Remote end can still send data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
require.True(t, valid)
// We can still ACK their data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Receive FIN from remote end
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain TIME-WAIT (waiting for possible retransmissions)
require.Equal(t, TCPStateTimeWait, conn.GetState())
})
// Test half-closed connection: remote end closes, local end continues sending data
t.Run("Remote Close, Local Data", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Receive FIN from remote
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// We can still send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Remote can still ACK our data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Send our FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Receive final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
})
}
func TestTCPAbnormalSequences(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test handling of unsolicited RST in various states
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
// Send SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive unsolicited RST (without proper ACK)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
// Receive RST with proper ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.Equal(t, TCPStateClosed, conn.GetState())
require.True(t, conn.IsTombstone())
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Connection Timeout", func(t *testing.T) {
// Establish a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Wait for the connection to timeout
time.Sleep(2 * shortTimeout)
// Force cleanup
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after timeout")
})
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Complete the connection close to enter TIME_WAIT
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// TIME_WAIT should have its own timeout value (usually 2*MSL)
// For the test, we're using a short timeout
time.Sleep(2 * shortTimeout)
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
})
}
func TestSynFlood(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
basePort := uint16(10000)
dstPort := uint16(80)
// Create a large number of SYN packets to simulate a SYN flood
for i := uint16(0); i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
}
// Check that we're tracking all connections
require.Equal(t, 1000, len(tracker.connections))
// Now simulate SYN timeout
var oldConns int
tracker.mutex.Lock()
for _, conn := range tracker.connections {
if conn.GetState() == TCPStateSynSent {
// Make the connection appear old
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
oldConns++
}
}
tracker.mutex.Unlock()
require.Equal(t, 1000, oldConns)
// Run cleanup
tracker.cleanup()
// Check that stale connections were cleaned up
require.Equal(t, 0, len(tracker.connections))
}
func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := netip.MustParseAddr("100.64.0.1")
serverIP := netip.MustParseAddr("100.64.0.2")
clientPort := uint16(12345)
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{
SrcIP: clientIP,
DstIP: serverIP,
SrcPort: clientPort,
DstPort: serverPort,
}
tracker.mutex.RLock()
conn := tracker.connections[key]
tracker.mutex.RUnlock()
require.NotNil(t, conn)
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
// 2. Server sends SYN-ACK response
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
// Server sends data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
}

View File

@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@@ -17,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -29,8 +32,10 @@ const (
)
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
@@ -167,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
}
return addr.AsSlice()
}
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
key := buildKey(srcIP, dstIP, srcPort, dstPort)
f.ruleIdMap.LoadOrStore(key, ruleID)
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
return value.([]byte), true
}
return nil, false
}
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return
}
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
}
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
return conntrack.ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
}
}()
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
}
return
return 0
}
ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
return
return 0
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
var rxPackets, txPackets uint64
if rxBytes > 0 {
rxPackets = 1
}
if txBytes > 0 {
txPackets = 1
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourceIP: srcIp,
DestIP: dstIp,
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -6,8 +6,10 @@ import (
"io"
"net"
"net/netip"
"sync"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
}
}()
@@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
}
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)

View File

@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
}
}()
@@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
f.udpForwarder.conns[id] = pConn
@@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil {
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}()
errChan := make(chan error, 2)
var wg sync.WaitGroup
wg.Add(2)
var txBytes, rxBytes int64
var outboundErr, inboundErr error
// outbound->inbound: copy from pConn.conn to pConn.outConn
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
defer wg.Done()
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
// inbound->outbound: copy from pConn.outConn to pConn.conn
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
defer wg.Done()
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
}
var rxPackets, txPackets uint64
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
rxPackets = udpStats.PacketsSent.Value()
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
// copy reads from src and writes to dst.
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
var totalBytes int64 = 0
for {
if ctx.Err() != nil {
return ctx.Err()
return totalBytes, ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
return totalBytes, fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
nWritten, err := dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
}
totalBytes += int64(nWritten)
c.updateLastSeen()
}
}

View File

@@ -14,8 +14,13 @@ import (
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
// fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [256]*ipv4LowBitmap
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
}
func newLocalIPManager() *localIPManager {
@@ -27,35 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
ipStr := ip.String()
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -73,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
@@ -86,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
}()
var newIPv4Bitmap [1 << 16]uint32
var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
}
if iface != nil {
@@ -120,12 +149,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false
return m.checkBitmapBit(ip.AsSlice())
}

View File

@@ -77,6 +77,18 @@ func TestLocalIPManager(t *testing.T) {
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
@@ -192,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
// Setup bitmap
bitmapManager := newLocalIPManager()
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
@@ -248,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm := newLocalIPManager()
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -257,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm := newLocalIPManager()
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))

View File

@@ -29,14 +29,15 @@ func (r *PeerRule) ID() string {
}
type RouteRule struct {
id string
mgmtId []byte
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
id string
mgmtId []byte
sources []netip.Prefix
dstSet firewall.Set
destinations []netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// ID returns the rule id

View File

@@ -198,12 +198,12 @@ func TestTracePacket(t *testing.T) {
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
@@ -222,12 +222,12 @@ func TestTracePacket(t *testing.T) {
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
@@ -245,7 +245,7 @@ func TestTracePacket(t *testing.T) {
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
@@ -263,7 +263,7 @@ func TestTracePacket(t *testing.T) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
@@ -425,8 +425,8 @@ func TestTracePacket(t *testing.T) {
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
"192.168.17.2 should not be recognized as a local IP")
pb := tc.packetBuilder()

View File

@@ -49,10 +49,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
type RouteRules []RouteRule
type RouteRules []*RouteRule
func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b RouteRule) int {
slices.SortStableFunc(r, func(a, b *RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1
@@ -99,6 +99,8 @@ type Manager struct {
forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
}
// decoder for packages
@@ -201,41 +203,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
}
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
}
if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
return m, nil
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder.Load() == nil {
return nil
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return fmt.Errorf("parse wireguard network: %w", err)
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering(
rule, err := m.addRouteFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix,
firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
); err != nil {
return fmt.Errorf("block wg nte : %w", err)
)
if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err)
}
// TODO: Block networks that we're a client of
return nil
return rule, nil
}
func (m *Manager) determineRouting() error {
@@ -413,10 +409,23 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
@@ -426,34 +435,39 @@ func (m *Manager) AddRouteFiltering(
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
mgmtId: id,
sources: sources,
destination: destination,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
}
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule)
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
}
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID
})
if idx < 0 {
@@ -509,6 +523,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.UpdateSet(set, prefixes)
}
m.mutex.Lock()
defer m.mutex.Unlock()
var matches []*RouteRule
for _, rule := range m.routeRules {
if rule.dstSet == set {
matches = append(matches, rule)
}
}
if len(matches) == 0 {
return fmt.Errorf("no route rule found for set: %s", set)
}
destinations := matches[0].destinations
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
if cmp != 0 {
return cmp
}
return a.Bits() - b.Bits()
})
destinations = slices.Compact(destinations)
for _, rule := range matches {
rule.destinations = destinations
}
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
return nil
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
@@ -658,7 +718,8 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) {
valid, fragment := m.isValidPacket(d, packetData)
if !valid {
return true
}
@@ -668,6 +729,13 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return true
}
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
@@ -678,7 +746,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
}
// handleLocalTraffic handles local traffic.
@@ -739,7 +807,7 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
@@ -749,13 +817,15 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
// Pass to native stack if native router is enabled or forced
if m.nativeRouter.Load() {
m.trackInbound(d, srcIP, dstIP, nil, size)
return false
}
proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
@@ -770,6 +840,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true
}
@@ -779,8 +851,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if fwd == nil {
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
} else {
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject routed packet: %v", err)
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
}
}
@@ -812,17 +887,32 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
}
}
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
// isValidPacket checks if the packet is valid.
// It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace("couldn't decode packet, err: %s", err)
return false
return false, false
}
if len(d.decoded) < 2 {
m.logger.Trace("packet doesn't have network and transport layers")
return false
l := len(d.decoded)
// L3 and L4 are mandatory
if l >= 2 {
return true, false
}
return true
// Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
ip4 := d.ip4
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
return true, true
}
}
m.logger.Trace("packet doesn't have network and transport layers")
return false, false
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
@@ -962,8 +1052,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
return nil, false
}
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) {
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
destMatched = true
break
}
}
if !destMatched {
return false
}
@@ -1065,7 +1162,22 @@ func (m *Manager) EnableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.determineRouting()
if err := m.determineRouting(); err != nil {
return fmt.Errorf("determine routing: %w", err)
}
if m.forwarder.Load() == nil {
return nil
}
rule, err := m.blockInvalidRouted(m.wgIface)
if err != nil {
return fmt.Errorf("block invalid routed: %w", err)
}
m.blockRule = rule
return nil
}
func (m *Manager) DisableRouting() error {
@@ -1090,5 +1202,12 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped")
if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
)
func TestPeerACLFiltering(t *testing.T) {
@@ -188,6 +189,281 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Block TCP traffic outside port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Edge Case - Port at Range Boundary",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8100,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "UDP Port Range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 5060,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
// New drop test cases
{
name: "Drop TCP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop UDP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop ICMP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop all traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop traffic from multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: false,
},
{
name: "Drop TCP traffic with source port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 32100,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Mixed rule - drop specific port but allow other ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
}
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
@@ -198,6 +474,28 @@ func TestPeerACLFiltering(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP(tc.ruleIP),
@@ -303,8 +601,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load())
@@ -321,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) {
type rule struct {
sources []netip.Prefix
dest netip.Prefix
dest fw.Network
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -347,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept,
@@ -363,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept,
@@ -379,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"),
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept,
@@ -395,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{53}},
action: fw.ActionAccept,
@@ -409,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("0.0.0.0/0"),
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolICMP,
action: fw.ActionAccept,
},
@@ -424,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -440,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -456,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -472,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -488,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345}},
action: fw.ActionAccept,
@@ -507,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -521,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL,
action: fw.ActionAccept,
},
@@ -536,33 +834,13 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Multiple source networks with mismatched protocol",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
// Should not match TCP rule
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
@@ -572,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
action: fw.ActionAccept,
@@ -588,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
action: fw.ActionAccept,
@@ -604,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL,
srcPort: &fw.Port{Values: []uint16{12345}},
dstPort: &fw.Port{Values: []uint16{80}},
@@ -621,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{
IsRange: true,
@@ -640,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{
IsRange: true,
@@ -659,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
srcPort: &fw.Port{
IsRange: true,
@@ -678,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
srcPort: &fw.Port{
IsRange: true,
@@ -700,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8100,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{
IsRange: true,
@@ -719,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 5060,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{
IsRange: true,
@@ -738,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL,
dstPort: &fw.Port{
IsRange: true,
@@ -757,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop,
@@ -773,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL,
action: fw.ActionDrop,
},
@@ -791,17 +1069,158 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop,
},
shouldPass: false,
},
{
name: "Drop empty destination set",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Set: fw.Set{}},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
action: fw.ActionDrop,
},
shouldPass: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources,
@@ -836,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) {
name string
rules []struct {
sources []netip.Prefix
dest netip.Prefix
dest fw.Network
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -857,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Drop rules take precedence over accept",
rules: []struct {
sources []netip.Prefix
dest netip.Prefix
dest fw.Network
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -866,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) {
{
// Accept rule added first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 443}},
action: fw.ActionAccept,
@@ -874,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) {
{
// Drop rule added second but should be evaluated first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop,
@@ -912,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Multiple drop rules take precedence",
rules: []struct {
sources []netip.Prefix
dest netip.Prefix
dest fw.Network
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -921,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) {
{
// Accept all
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"),
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolALL,
action: fw.ActionAccept,
},
{
// Drop specific port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop,
@@ -936,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) {
{
// Drop different port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop,
@@ -1015,3 +1434,53 @@ func TestRouteACLOrder(t *testing.T) {
})
}
}
func TestRouteACLSet(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -1,7 +1,6 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
@@ -21,10 +20,11 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/management/domain"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
@@ -712,3 +712,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
})
}
}
func TestUpdateSetMerge(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
// Update the set with initial prefixes
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Test initial prefixes work
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP1 := netip.MustParseAddr("10.0.0.100")
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
newPrefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.1.0.0/24"),
}
err = manager.UpdateSet(set, newPrefixes)
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
// Check that new prefixes are included
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
// Verify the rule has all prefixes
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
}
func TestUpdateSetDeduplication(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
}
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Check the internal state for deduplication
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
// Check the prefixes are correct
expectedPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
for i, prefix := range expectedPrefixes {
require.True(t, r.destinations[i] == prefix,
"Prefix should match expected value")
}
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
// Test with overlapping prefixes of different sizes
overlappingPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/16"), // More general
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
netip.MustParsePrefix("192.168.0.0/16"), // More general
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
}
err = manager.UpdateSet(set, overlappingPrefixes)
require.NoError(t, err)
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")
// Verify they're sorted correctly (more specific prefixes should come first)
prefixes := make([]string, 0, len(r.destinations))
for _, p := range r.destinations {
prefixes = append(prefixes, p.String())
}
// Check sorted order
require.Equal(t, []string{
"10.0.0.0/16",
"10.0.0.0/24",
"192.168.0.0/16",
"192.168.1.0/24",
}, prefixes, "Prefixes should be sorted")
}
}
manager.mutex.RUnlock()
// Test functionality with all prefixes
testCases := []struct {
dstIP netip.Addr
expected bool
desc string
}{
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
}
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
params.Logger = getLogger()
}
mux := &UDPMuxDefault{
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
buf: make([]byte, size),
}
}
func getLogger() logging.LeveledLogger {
fac := logging.NewDefaultLoggerFactory()
//fac.Writer = log.StandardLogger().Writer()
return fac.NewLogger("ice")
}

View File

@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
params.Logger = getLogger()
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25

View File

@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
func getFwmark() int {
if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark
return nbnet.ControlPlaneMark
}
return 0
}

View File

@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination netip.Prefix,
destination manager.Network,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
type protoMatch struct {
@@ -53,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
d.mutex.Lock()
defer d.mutex.Unlock()
@@ -82,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
@@ -176,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
@@ -208,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty
}
@@ -222,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
sources = append(sources, source)
}
var destination netip.Prefix
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
destination, err := determineDestination(rule, dynamicResolver, sources)
if err != nil {
return "", fmt.Errorf("determine destination: %w", err)
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
@@ -580,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil
}
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@@ -1,7 +1,6 @@
package acl
import (
"context"
"net"
"testing"
@@ -15,7 +14,7 @@ import (
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
@@ -67,7 +66,7 @@ func TestDefaultManager(t *testing.T) {
acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
@@ -93,7 +92,7 @@ func TestDefaultManager(t *testing.T) {
},
)
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 {
@@ -117,13 +116,13 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
@@ -360,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}(fw)
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap)
acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))

View File

@@ -94,12 +94,17 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
state,
params := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
)
}
if !p.providerConfig.DisablePromptLogin {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
return AuthFlowInfo{
VerificationURIComplete: authURL,

View File

@@ -0,0 +1,49 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestPromptLogin(t *testing.T) {
tt := []struct {
name string
prompt bool
}{
{"PromptLogin", true},
{"NoPromptLogin", false},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
DisablePromptLogin: !tc.prompt,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
}
authInfo, err := pkce.RequestAuthInfo(context.Background())
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
pattern := "prompt=login"
if tc.prompt {
require.Contains(t, authInfo.VerificationURIComplete, pattern)
} else {
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
}
})
}
}

View File

@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
// GetLatestNetworkMap returns the latest network map from the engine.
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
engine := c.Engine()
if engine == nil {
return nil, errors.New("engine is not initialized")
}
networkMap, err := engine.GetLatestNetworkMap()
if err != nil {
return nil, fmt.Errorf("get latest network map: %w", err)
}
if networkMap == nil {
return nil, errors.New("network map is not available")
}
return networkMap, nil
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,8 @@
//go:build linux && !android
package server
package debug
import (
"archive/zip"
"bytes"
"encoding/binary"
"fmt"
@@ -14,36 +13,31 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/proto"
)
// addFirewallRules collects and adds firewall rules to the archive
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
// Collect and add iptables rules
iptablesRules, err := collectIPTablesRules()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
} else {
if req.GetAnonymize() {
iptablesRules = anonymizer.AnonymizeString(iptablesRules)
if g.anonymize {
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
}
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
}
}
// Collect and add nftables rules
nftablesRules, err := collectNFTablesRules()
if err != nil {
log.Warnf("Failed to collect nftables rules: %v", err)
} else {
if req.GetAnonymize() {
nftablesRules = anonymizer.AnonymizeString(nftablesRules)
if g.anonymize {
nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
}
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
log.Warnf("Failed to add nftables rules to bundle: %v", err)
}
}
@@ -65,16 +59,23 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n")
}
// Then get verbose statistics for each table
// Collect ipset information
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
// Get list of tables
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
// Get verbose statistics for the entire table
stats, err := getTableStatistics(table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
@@ -87,6 +88,28 @@ func collectIPTablesRules() (string, error) {
return builder.String(), nil
}
// collectIPSets collects information about ipsets
func collectIPSets() (string, error) {
cmd := exec.Command("ipset", "list")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("ipset command not found: %w", err)
}
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
}
ipsets := stdout.String()
if strings.TrimSpace(ipsets) == "" {
return "No ipsets found", nil
}
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")
@@ -182,12 +205,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
continue
}
// Format chains
for _, chain := range chains {
formatChain(conn, table, chain, &builder)
}
// Format sets
if sets, err := conn.GetSets(table); err != nil {
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
} else if len(sets) > 0 {

View File

@@ -0,0 +1,7 @@
//go:build ios || android
package debug
func (g *BundleGenerator) addRoutes() error {
return nil
}

View File

@@ -0,0 +1,8 @@
//go:build !linux || android
package debug
// collectFirewallRules returns nothing on non-linux systems
func (g *BundleGenerator) addFirewallRules() error {
return nil
}

View File

@@ -0,0 +1,25 @@
//go:build !ios && !android
package debug
import (
"fmt"
"strings"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func (g *BundleGenerator) addRoutes() error {
routes, err := systemops.GetRoutesFromTable()
if err != nil {
return fmt.Errorf("get routes: %w", err)
}
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
routesReader := strings.NewReader(routesContent)
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
return nil
}

View File

@@ -0,0 +1,543 @@
package debug
import (
"encoding/json"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
"test_state": mustMarshal(map[string]any{
// Test simple fields
"public_ip": "203.0.113.1",
"private_ip": "192.168.1.1",
"protected_ip": "100.64.0.1",
"well_known_ip": "8.8.8.8",
"ipv6_addr": "2001:db8::1",
"private_ipv6": "fd00::1",
"domain": "test.example.com",
"uri": "stun:stun.example.com:3478",
"uri_with_ip": "turn:203.0.113.1:3478",
"netbird_domain": "device.netbird.cloud",
// Test CIDR ranges
"public_cidr": "203.0.113.0/24",
"private_cidr": "192.168.0.0/16",
"protected_cidr": "100.64.0.0/10",
"ipv6_cidr": "2001:db8::/32",
"private_ipv6_cidr": "fd00::/8",
// Test nested structures
"nested": map[string]any{
"ip": "203.0.113.2",
"domain": "nested.example.com",
"more_nest": map[string]any{
"ip": "203.0.113.3",
"domain": "deep.example.com",
},
},
// Test arrays
"string_array": []any{
"203.0.113.4",
"test1.example.com",
"test2.example.com",
},
"object_array": []any{
map[string]any{
"ip": "203.0.113.5",
"domain": "array1.example.com",
},
map[string]any{
"ip": "203.0.113.6",
"domain": "array2.example.com",
},
},
// Test multiple occurrences of same value
"duplicate_ip": "203.0.113.1", // Same as public_ip
"duplicate_domain": "test.example.com", // Same as domain
// Test URIs with various schemes
"stun_uri": "stun:stun.example.com:3478",
"turns_uri": "turns:turns.example.com:5349",
"http_uri": "http://web.example.com:80",
"https_uri": "https://secure.example.com:443",
// Test strings that might look like IPs but aren't
"not_ip": "300.300.300.300",
"partial_ip": "192.168",
"ip_like_string": "1234.5678",
// Test mixed content strings
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
// Test empty and special values
"empty_string": "",
"null_value": nil,
"numeric_value": 42,
"boolean_value": true,
}),
"route_state": mustMarshal(map[string]any{
"routes": []any{
map[string]any{
"network": "203.0.113.0/24",
"gateway": "203.0.113.1",
"domains": []any{
"route1.example.com",
"route2.example.com",
},
},
map[string]any{
"network": "2001:db8::/32",
"gateway": "2001:db8::1",
"domains": []any{
"route3.example.com",
"route4.example.com",
},
},
},
// Test map with IP/CIDR keys
"refCountMap": map[string]any{
"203.0.113.1/32": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"2001:db8::1/128": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "fe80::1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
},
},
},
}),
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Pre-seed the domains we need to verify in the test assertions
anonymizer.AnonymizeDomain("test.example.com")
anonymizer.AnonymizeDomain("nested.example.com")
anonymizer.AnonymizeDomain("deep.example.com")
anonymizer.AnonymizeDomain("array1.example.com")
err := anonymizeStateFile(&testState, anonymizer)
require.NoError(t, err)
// Helper function to unmarshal and get nested values
var state map[string]any
err = json.Unmarshal(testState["test_state"], &state)
require.NoError(t, err)
// Test null state remains unchanged
require.Equal(t, "null", string(testState["null_state"]))
// Basic assertions
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
// CIDR ranges
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
// Nested structures
nested := state["nested"].(map[string]any)
assert.NotEqual(t, "203.0.113.2", nested["ip"])
assert.NotEqual(t, "nested.example.com", nested["domain"])
moreNest := nested["more_nest"].(map[string]any)
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
// Arrays
strArray := state["string_array"].([]any)
assert.NotEqual(t, "203.0.113.4", strArray[0])
assert.NotEqual(t, "test1.example.com", strArray[1])
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
objArray := state["object_array"].([]any)
firstObj := objArray[0].(map[string]any)
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
// Duplicate values should be anonymized consistently
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
assert.Equal(t, state["domain"], state["duplicate_domain"])
// URIs
assert.NotContains(t, state["stun_uri"], "stun.example.com")
assert.NotContains(t, state["turns_uri"], "turns.example.com")
assert.NotContains(t, state["http_uri"], "web.example.com")
assert.NotContains(t, state["https_uri"], "secure.example.com")
// Non-IP strings should remain unchanged
assert.Equal(t, "300.300.300.300", state["not_ip"])
assert.Equal(t, "192.168", state["partial_ip"])
assert.Equal(t, "1234.5678", state["ip_like_string"])
// Mixed content should have IPs and domains replaced
mixedContent := state["mixed_content"].(string)
assert.NotContains(t, mixedContent, "203.0.113.1")
assert.NotContains(t, mixedContent, "test.example.com")
assert.Contains(t, mixedContent, "Server at ")
assert.Contains(t, mixedContent, " on port 80")
// Special values should remain unchanged
assert.Equal(t, "", state["empty_string"])
assert.Nil(t, state["null_value"])
assert.Equal(t, float64(42), state["numeric_value"])
assert.Equal(t, true, state["boolean_value"])
// Check route state
var routeState map[string]any
err = json.Unmarshal(testState["route_state"], &routeState)
require.NoError(t, err)
routes := routeState["routes"].([]any)
route1 := routes[0].(map[string]any)
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
assert.Contains(t, route1["network"], "/24")
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
domains := route1["domains"].([]any)
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
// Check map keys are anonymized
refCountMap := routeState["refCountMap"].(map[string]any)
hasPublicIPKey := false
hasIPv6Key := false
hasPrivateIPKey := false
for key := range refCountMap {
if strings.Contains(key, "203.0.113.1") {
hasPublicIPKey = true
}
if strings.Contains(key, "2001:db8::1") {
hasIPv6Key = true
}
if key == "10.0.0.1/32" {
hasPrivateIPKey = true
}
}
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
func TestAnonymizeNetworkMap(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{
AllowedIps: []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
},
Fqdn: "peer2.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
},
},
},
Routes: []*mgmProto.Route{
{
Network: "197.51.100.0/24",
Domains: []string{"prod.example.com", "staging.example.com"},
NetID: "net-123abc",
},
},
DNSConfig: &mgmProto.DNSConfig{
NameServerGroups: []*mgmProto.NameServerGroup{
{
NameServers: []*mgmProto.NameServer{
{IP: "8.8.8.8"},
{IP: "1.1.1.1"},
{IP: "203.0.113.53"},
},
Domains: []string{"example.com", "internal.example.com"},
},
},
CustomZones: []*mgmProto.CustomZone{
{
Domain: "custom.example.com",
Records: []*mgmProto.SimpleRecord{
{
Name: "www.custom.example.com",
Type: 1,
RData: "203.0.113.10",
},
{
Name: "internal.custom.example.com",
Type: 1,
RData: "192.168.1.10",
},
},
},
},
},
}
// Create anonymizer with test addresses
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Anonymize the network map
err := anonymizeNetworkMap(networkMap, anonymizer)
require.NoError(t, err)
// Test PeerConfig anonymization
peerCfg := networkMap.PeerConfig
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
// Verify DNS and FQDN are properly anonymized
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
// Test RemotePeers anonymization
remotePeer := networkMap.RemotePeers[0]
// Verify FQDN is anonymized
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
// Check that public IPs are anonymized but private IPs are preserved
for _, allowedIP := range remotePeer.AllowedIps {
ip, _, err := net.ParseCIDR(allowedIP)
require.NoError(t, err)
if ip.IsPrivate() || isInCGNATRange(ip) {
require.Contains(t, []string{
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
}, allowedIP)
} else {
require.NotContains(t, []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
}, allowedIP)
}
}
// Test Routes anonymization
route := networkMap.Routes[0]
require.NotEqual(t, "197.51.100.0/24", route.Network)
for _, domain := range route.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test DNS config anonymization
dnsConfig := networkMap.DNSConfig
nameServerGroup := dnsConfig.NameServerGroups[0]
// Verify well-known DNS servers are preserved
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
// Verify public DNS server is anonymized
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
// Verify domains are anonymized
for _, domain := range nameServerGroup.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test CustomZones anonymization
customZone := dnsConfig.CustomZones[0]
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
require.NotContains(t, customZone.Domain, "example.com")
// Verify records are properly anonymized
for _, record := range customZone.Records {
require.True(t, strings.HasSuffix(record.Name, ".domain"))
require.NotContains(t, record.Name, "example.com")
ip := net.ParseIP(record.RData)
if ip != nil {
if !ip.IsPrivate() {
require.NotEqual(t, "203.0.113.10", record.RData)
} else {
require.Equal(t, "192.168.1.10", record.RData)
}
}
}
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
return cgnat.Contains(ip)
}
func TestAnonymizeFirewallRules(t *testing.T) {
// TODO: Add ipv6
// Example iptables-save output
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter
:INPUT ACCEPT [0:0]
:FORWARD ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
-A INPUT -s 192.168.1.0/24 -j ACCEPT
-A INPUT -s 44.192.140.1/32 -j DROP
-A FORWARD -s 10.0.0.0/8 -j DROP
-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT
COMMIT
*nat
:PREROUTING ACCEPT [0:0]
:INPUT ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
:POSTROUTING ACCEPT [0:0]
-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE
-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80
COMMIT`
// Example iptables -v -n -L output
iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination
0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0
100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0
Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination
0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0
25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination`
// Example nftables output
nftablesRules := `table inet filter {
chain input {
type filter hook input priority filter; policy accept;
ip saddr 192.168.1.1 accept
ip saddr 44.192.140.1 drop
}
chain forward {
type filter hook forward priority filter; policy accept;
ip saddr 10.0.0.0/8 drop
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
}
}`
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Test iptables-save anonymization
anonIptablesSave := anonymizer.AnonymizeString(iptablesSave)
// Private IP addresses should remain unchanged
assert.Contains(t, anonIptablesSave, "192.168.1.0/24")
assert.Contains(t, anonIptablesSave, "10.0.0.0/8")
assert.Contains(t, anonIptablesSave, "192.168.100.0/24")
assert.Contains(t, anonIptablesSave, "192.168.1.10")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonIptablesSave, "44.192.140.1")
assert.NotContains(t, anonIptablesSave, "44.192.140.0/24")
assert.NotContains(t, anonIptablesSave, "52.84.12.34")
assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range
// Structure should be preserved
assert.Contains(t, anonIptablesSave, "*filter")
assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]")
assert.Contains(t, anonIptablesSave, "COMMIT")
assert.Contains(t, anonIptablesSave, "-j MASQUERADE")
assert.Contains(t, anonIptablesSave, "--dport 80")
// Test iptables verbose output anonymization
anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose)
// Private IP addresses should remain unchanged
assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24")
assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonIptablesVerbose, "44.192.140.1")
assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24")
assert.NotContains(t, anonIptablesVerbose, "52.84.12.34")
assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range
// Structure and counters should be preserved
assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)")
assert.Contains(t, anonIptablesVerbose, "100 1024 DROP")
assert.Contains(t, anonIptablesVerbose, "pkts bytes target")
// Test nftables anonymization
anonNftables := anonymizer.AnonymizeString(nftablesRules)
// Private IP addresses should remain unchanged
assert.Contains(t, anonNftables, "192.168.1.1")
assert.Contains(t, anonNftables, "10.0.0.0/8")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonNftables, "44.192.140.1")
assert.NotContains(t, anonNftables, "44.192.140.0/24")
assert.NotContains(t, anonNftables, "52.84.12.34")
assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range
// Structure should be preserved
assert.Contains(t, anonNftables, "table inet filter {")
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}

View File

@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue
}
listOfDomains = append(listOfDomains, dConf.Domain)
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
}
return listOfDomains
}

View File

@@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
}
// First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
c.removeEntry(origPattern, priority)
// Check if handler implements SubdomainMatcher interface
matchSubdomains := false
@@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern)
c.removeEntry(pattern, priority)
}
func (c *HandlerChain) removeEntry(pattern string, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
break
}
}
}
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return

View File

@@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
for _, handler := range handlers {
handler.AssertExpectations(t)
}
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
})
}
}
@@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA)
// Keep track of mocks for the final assertion in Step 4
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
chain.ServeDNS(w, r)
chain.ServeDNS(w1, r)
routeHandler.AssertExpectations(t)
routeHandler.ExpectedCalls = nil
routeHandler.Calls = nil
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
chain.ServeDNS(w, r)
chain.ServeDNS(w2, r)
matchHandler.AssertExpectations(t)
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
chain.ServeDNS(w3, r)
defaultHandler.AssertExpectations(t)
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain))
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
m.AssertNumberOfCalls(t, "ServeDNS", 0)
}
}
func TestHandlerChain_CaseSensitivity(t *testing.T) {
@@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
})
}
}
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct {
name string
addPattern string
removePattern string
queryPattern string
shouldBeRemoved bool
description string
}{
{
name: "exact same pattern",
addPattern: "example.com.",
removePattern: "example.com.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical patterns",
},
{
name: "case difference",
addPattern: "Example.Com.",
removePattern: "EXAMPLE.COM.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding with mixed case, removing with uppercase",
},
{
name: "reversed case difference",
addPattern: "EXAMPLE.ORG.",
removePattern: "example.org.",
queryPattern: "example.org.",
shouldBeRemoved: true,
description: "Adding with uppercase, removing with lowercase",
},
{
name: "add wildcard, remove wildcard",
addPattern: "*.example.com.",
removePattern: "*.example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical wildcard patterns",
},
{
name: "add wildcard, remove transformed pattern",
addPattern: "*.example.net.",
removePattern: "example.net.",
queryPattern: "sub.example.net.",
shouldBeRemoved: false,
description: "Adding with wildcard, removing with non-wildcard pattern",
},
{
name: "add transformed pattern, remove wildcard",
addPattern: "example.io.",
removePattern: "*.example.io.",
queryPattern: "example.io.",
shouldBeRemoved: false,
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
},
{
name: "trailing dot difference",
addPattern: "example.dev",
removePattern: "example.dev.",
queryPattern: "example.dev.",
shouldBeRemoved: true,
description: "Adding without trailing dot, removing with trailing dot",
},
{
name: "reversed trailing dot difference",
addPattern: "example.app.",
removePattern: "example.app",
queryPattern: "example.app.",
shouldBeRemoved: true,
description: "Adding with trailing dot, removing without trailing dot",
},
{
name: "mixed case and wildcard",
addPattern: "*.Example.Site.",
removePattern: "*.EXAMPLE.SITE.",
queryPattern: "sub.example.site.",
shouldBeRemoved: true,
description: "Adding mixed case wildcard, removing uppercase wildcard",
},
{
name: "root zone",
addPattern: ".",
removePattern: ".",
queryPattern: "random.domain.",
shouldBeRemoved: true,
description: "Adding and removing root zone",
},
{
name: "wrong domain",
addPattern: "example.com.",
removePattern: "different.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding one domain, trying to remove a different domain",
},
{
name: "subdomain mismatch",
addPattern: "sub.example.com.",
removePattern: "example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: false,
description: "Adding subdomain, trying to remove parent domain",
},
{
name: "parent domain mismatch",
addPattern: "example.com.",
removePattern: "sub.example.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding parent domain, trying to remove subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)
handler.AssertNotCalled(t, "ServeDNS")
// Add handler
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
// Verify handler is called after adding
handler.On("ServeDNS", mock.Anything, r).Once()
chain.ServeDNS(w, r)
handler.AssertExpectations(t)
// Reset mock for the next test
handler.ExpectedCalls = nil
// Remove handler
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
// Set up expectations based on whether removal should succeed
if !tt.shouldBeRemoved {
handler.On("ServeDNS", mock.Anything, r).Once()
}
// Test if handler is still called after removal attempt
chain.ServeDNS(w, r)
if tt.shouldBeRemoved {
handler.AssertNotCalled(t, "ServeDNS",
"Handler should not be called after successful removal with pattern %q",
tt.removePattern)
} else {
handler.AssertExpectations(t)
handler.ExpectedCalls = nil
}
})
}
}

View File

@@ -5,6 +5,8 @@ import (
"net/netip"
"strings"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
)
@@ -12,8 +14,8 @@ import (
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
const (
ipv4ReverseZone = ".in-addr.arpa"
ipv6ReverseZone = ".ip6.arpa"
ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa."
)
type hostManager interface {
@@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
for _, domain := range nsConfig.Domains {
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(domain, "."),
Domain: strings.ToLower(dns.Fqdn(domain)),
MatchOnly: !nsConfig.SearchDomainsEnabled,
})
}
@@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."),
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
MatchOnly: matchOnly,
})
}

View File

@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue
}
if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain)
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
continue
}
searchDomains = append(searchDomains, dConf.Domain)
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

View File

@@ -17,9 +17,12 @@ import (
var (
userenv = syscall.NewLazyDLL("userenv.dll")
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
)
const (
@@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
continue
}
if !dConf.MatchOnly {
searchDomains = append(searchDomains, dConf.Domain)
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
}
matchDomains = append(matchDomains, "."+dConf.Domain)
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
}
if len(matchDomains) != 0 {
@@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err)
}
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}
@@ -184,6 +191,26 @@ func (r *registryConfigurator) string() string {
return "registry"
}
func (r *registryConfigurator) flushDNSCache() error {
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() {
if rec := recover(); rec != nil {
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
}
}()
ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
}
return fmt.Errorf("DnsFlushResolverCache failed")
}
log.Info("flushed DNS cache")
return nil
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("update search domains: %w", err)
@@ -236,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err)
}
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}

View File

@@ -71,6 +71,12 @@ func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
value, found := d.records.Load(key)
if !found {
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
r.Question[0].Qtype = dns.TypeCNAME
return d.lookupRecords(r)
}
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
)
// MockServer is the mock instance of a dns server
@@ -13,17 +14,17 @@ type MockServer struct {
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler, int)
DeregisterHandlerFunc func([]string, int)
RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func(domain.List, int)
}
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
if m.RegisterHandlerFunc != nil {
m.RegisterHandlerFunc(domains, handler, priority)
}
}
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
if m.DeregisterHandlerFunc != nil {
m.DeregisterHandlerFunc(domains, priority)
}

View File

@@ -13,7 +13,6 @@ import (
"github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
continue
}
if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
matchDomains = append(matchDomains, "~."+dConf.Domain)
continue
}
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
searchDomains = append(searchDomains, dConf.Domain)
}
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic

View File

@@ -6,11 +6,13 @@ import (
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -18,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
)
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
@@ -32,8 +35,8 @@ type IosDnsManager interface {
// Server is a dns server interface
type Server interface {
RegisterHandler(domains []string, handler dns.Handler, priority int)
DeregisterHandler(domains []string, priority int)
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int)
Initialize() error
Stop()
DnsIP() string
@@ -65,6 +68,7 @@ type DefaultServer struct {
previousConfigHash uint64
currentConfig HostDNSConfig
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
// permanent related properties
permanent bool
@@ -164,13 +168,15 @@ func newDefaultServer(
stateManager *statemanager.Manager,
disableSys bool,
) *DefaultServer {
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
ctxCancel: stop,
disableSys: disableSys,
service: dnsService,
handlerChain: NewHandlerChain(),
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
dnsMuxMap: make(registeredHandlerMap),
localResolver: &localResolver{
registeredMap: make(registrationMap),
@@ -181,14 +187,26 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
}
// register with root zone, handler chain takes care of the routing
dnsService.RegisterMux(".", handlerChain)
return defaultServer
}
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
// RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.registerHandler(domains, handler, priority)
s.registerHandler(domains.ToPunycodeList(), handler, priority)
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
for _, domain := range domains {
// convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++
}
s.applyHostConfig()
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
continue
}
s.handlerChain.AddHandler(domain, handler, priority)
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
}
}
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
// DeregisterHandler deregisters the handler for the given domains with the given priority.
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.deregisterHandler(domains, priority)
s.deregisterHandler(domains.ToPunycodeList(), priority)
for _, domain := range domains {
zone := toZone(domain)
s.extraDomains[zone]--
if s.extraDomains[zone] <= 0 {
delete(s.extraDomains, zone)
}
}
s.applyHostConfig()
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
@@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
}
s.handlerChain.RemoveHandler(domain, priority)
// Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) {
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
}
}
}
@@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
}
s.service.Stop()
maps.Clear(s.extraDomains)
}
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
@@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// is the service should be Disabled, we stop the listener or fake resolver
// and proceed with a regular update to clean up the handlers and records
if update.ServiceEnable {
_ = s.service.Listen()
if err := s.service.Listen(); err != nil {
log.Errorf("failed to start DNS service: %v", err)
}
} else if !s.permanent {
s.service.Stop()
}
@@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
hostUpdate := s.currentConfig
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
hostUpdate.RouteAll = false
s.currentConfig.RouteAll = false
}
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
log.Error(err)
s.handleErrNoGroupaAll(err)
}
s.applyHostConfig()
go func() {
// persist dns state right away
@@ -441,6 +462,43 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
return nil
}
func (s *DefaultServer) applyHostConfig() {
if s.hostManager == nil {
return
}
// prevent reapplying config if we're shutting down
if s.ctx.Err() != nil {
return
}
config := s.currentConfig
existingDomains := make(map[string]struct{})
for _, d := range config.Domains {
existingDomains[d.Domain] = struct{}{}
}
// add extra domains only if they're not already in the config
for domain := range s.extraDomains {
domainStr := domain.PunycodeString()
if _, exists := existingDomains[domainStr]; !exists {
config.Domains = append(config.Domains, DomainConfig{
Domain: domainStr,
MatchOnly: true,
})
}
}
log.Debugf("extra match domains: %v", s.extraDomains)
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
s.handleErrNoGroupaAll(err)
}
}
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
return
@@ -690,10 +748,7 @@ func (s *DefaultServer) upstreamCallbacks(
}
}
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
s.handleErrNoGroupaAll(err)
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
s.applyHostConfig()
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
@@ -728,12 +783,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
}
if s.hostManager != nil {
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
s.handleErrNoGroupaAll(err)
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
}
s.applyHostConfig()
s.updateNSState(nsGroup, nil, true)
}
@@ -836,3 +886,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
return result
}
func toZone(d domain.Domain) domain.Domain {
return domain.Domain(
nbdns.NormalizeZone(
dns.Fqdn(
strings.ToLower(d.PunycodeString()),
),
),
)
}

View File

@@ -29,16 +29,17 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/domain"
)
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type mocWGIface struct {
filter device.PacketFilter
}
func (w *mocWGIface) Name() string {
panic("implement me")
return "utun2301"
}
func (w *mocWGIface) Address() wgaddr.Address {
@@ -1448,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
})
}
}
func TestExtraDomains(t *testing.T) {
tests := []struct {
name string
initialConfig nbdns.Config
registerDomains []domain.List
deregisterDomains []domain.List
finalConfig nbdns.Config
expectedDomains []string
expectedMatchOnly []string
applyHostConfigCall int
}{
{
name: "Register domains before config update",
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
},
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
expectedDomains: []string{
"config.example.com.",
"extra1.example.com.",
"extra2.example.com.",
},
expectedMatchOnly: []string{
"extra1.example.com.",
"extra2.example.com.",
},
applyHostConfigCall: 2,
},
{
name: "Register domains after config update",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"extra1.example.com.",
"extra2.example.com.",
},
expectedMatchOnly: []string{
"extra1.example.com.",
"extra2.example.com.",
},
applyHostConfigCall: 2,
},
{
name: "Register overlapping domains",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "overlap.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "overlap.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"overlap.example.com.",
"extra.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 2,
},
{
name: "Register and deregister domains",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra1.example.com", "extra2.example.com"},
{"extra3.example.com", "extra4.example.com"},
},
deregisterDomains: []domain.List{
{"extra1.example.com", "extra3.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"extra2.example.com.",
"extra4.example.com.",
},
expectedMatchOnly: []string{
"extra2.example.com.",
"extra4.example.com.",
},
applyHostConfigCall: 4,
},
{
name: "Register domains with ref counter",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "duplicate.example.com"},
{"other.example.com", "duplicate.example.com"},
},
deregisterDomains: []domain.List{
{"duplicate.example.com"},
},
expectedDomains: []string{
"config.example.com.",
"extra.example.com.",
"other.example.com.",
"duplicate.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
"other.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 4,
},
{
name: "Config update with new domains after registration",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "duplicate.example.com"},
},
finalConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "newconfig.example.com"},
},
},
expectedDomains: []string{
"config.example.com.",
"newconfig.example.com.",
"extra.example.com.",
"duplicate.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 3,
},
{
name: "Deregister domain that is part of customZones",
initialConfig: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "protected.example.com"},
},
},
registerDomains: []domain.List{
{"extra.example.com", "protected.example.com"},
},
deregisterDomains: []domain.List{
{"protected.example.com"},
},
expectedDomains: []string{
"extra.example.com.",
"config.example.com.",
"protected.example.com.",
},
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 3,
},
{
name: "Register domain that is part of nameserver group",
initialConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
},
registerDomains: []domain.List{
{"extra.example.com", "overlap.ns.example.com"},
},
expectedDomains: []string{
"ns.example.com.",
"overlap.ns.example.com.",
"extra.example.com.",
},
expectedMatchOnly: []string{
"ns.example.com.",
"overlap.ns.example.com.",
"extra.example.com.",
},
applyHostConfigCall: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedConfigs []HostDNSConfig
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
capturedConfigs = append(capturedConfigs, config)
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
wgInterface: &mocWGIface{},
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
// Apply initial configuration
if tt.initialConfig.ServiceEnable {
err := server.applyConfiguration(tt.initialConfig)
assert.NoError(t, err)
}
// Register domains
for _, domains := range tt.registerDomains {
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
}
// Deregister domains if specified
for _, domains := range tt.deregisterDomains {
server.DeregisterHandler(domains, PriorityDefault)
}
// Apply final configuration if specified
if tt.finalConfig.ServiceEnable {
err := server.applyConfiguration(tt.finalConfig)
assert.NoError(t, err)
}
// Verify number of calls
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
// Get the last applied config
lastConfig := capturedConfigs[len(capturedConfigs)-1]
// Check all expected domains are present
domainMap := make(map[string]bool)
matchOnlyMap := make(map[string]bool)
for _, d := range lastConfig.Domains {
domainMap[d.Domain] = true
if d.MatchOnly {
matchOnlyMap[d.Domain] = true
}
}
// Verify expected domains
for _, d := range tt.expectedDomains {
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
}
// Verify match-only domains
for _, d := range tt.expectedMatchOnly {
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
}
// Verify no unexpected domains
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
})
}
}
func TestExtraDomainsRefCounting(t *testing.T) {
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
// Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
// Verify refcount is 2
zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
// Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
// Deregister the other handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
// Verify domain is removed
_, exists := server.extraDomains[zoneKey]
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
}
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
var capturedConfig HostDNSConfig
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
capturedConfig = config
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
initialConfig := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
}
err := server.applyConfiguration(initialConfig)
assert.NoError(t, err)
var domains []string
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
}
assert.Contains(t, domains, "config.example.com.")
assert.Contains(t, domains, "extra.example.com.")
// Now apply a new configuration with overlapping domain
updatedConfig := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
{Domain: "extra.example.com"},
},
}
err = server.applyConfiguration(updatedConfig)
assert.NoError(t, err)
// Verify both domains are in config, but no duplicates
domains = []string{}
matchOnlyCount := 0
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
if d.MatchOnly {
matchOnlyCount++
}
}
assert.Contains(t, domains, "config.example.com.")
assert.Contains(t, domains, "extra.example.com.")
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
// Extra domain should no longer be marked as match-only when in config
matchOnlyDomain := ""
for _, d := range capturedConfig.Domains {
if d.Domain == "extra.example.com." && d.MatchOnly {
matchOnlyDomain = d.Domain
break
}
}
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
}
func TestDomainCaseHandling(t *testing.T) {
var capturedConfig HostDNSConfig
mockHostConfig := &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
capturedConfig = config
return nil
},
restoreHostDNSFunc: func() error {
return nil
},
supportCustomPortFunc: func() bool {
return true
},
stringFunc: func() string {
return "mock"
},
}
mockSvc := &mockService{}
server := &DefaultServer{
ctx: context.Background(),
handlerChain: NewHandlerChain(),
hostManager: mockHostConfig,
localResolver: &localResolver{},
service: mockSvc,
statusRecorder: peer.NewRecorder("test"),
extraDomains: make(map[domain.Domain]int),
}
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{Domain: "config.example.com"},
},
}
err := server.applyConfiguration(config)
assert.NoError(t, err)
var domains []string
for _, d := range capturedConfig.Domains {
domains = append(domains, d.Domain)
}
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/godbus/dbus/v5"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
@@ -38,7 +37,6 @@ const (
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
}
@@ -112,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
continue
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: dns.Fqdn(dConf.Domain),
Domain: dConf.Domain,
MatchOnly: dConf.MatchOnly,
})
@@ -124,18 +122,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
}
if config.RouteAll {
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
if err != nil {
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
return fmt.Errorf("set link as default dns router: %w", err)
}
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
Domain: nbdns.RootZone,
MatchOnly: true,
})
s.routingAll = true
} else if s.routingAll {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} else {
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
return fmt.Errorf("remove link as default dns router: %w", err)
}
}
state := &ShutdownState{
@@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
if err != nil {
log.Error(err)
}
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}
@@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
if err != nil {
return fmt.Errorf("setting domains configuration failed with error: %w", err)
}
return s.flushCaches()
return nil
}
func (s *systemdDbusConfigurator) restoreHostDNS() error {
@@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
}
return s.flushCaches()
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}
func (s *systemdDbusConfigurator) flushCaches() error {
func (s *systemdDbusConfigurator) flushDNSCache() error {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)

View File

@@ -18,14 +18,16 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
const (
UpstreamTimeout = 15 * time.Second
failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second
upstreamTimeout = 15 * time.Second
probeTimeout = 2 * time.Second
)
@@ -66,7 +68,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: upstreamTimeout,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder,
@@ -106,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
@@ -336,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
_, _, err := u.upstreamClient.exchange(ctx, server, r)
return err
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
client.UDPSize = iface.DefaultMTU - (60 + 8)
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
if rm == nil || !rm.MsgHdr.Truncated {
return rm, t, nil
}
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
client.Net = "tcp"
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
return rm, t, nil
}

View File

@@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
timeout := upstreamTimeout
timeout := UpstreamTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}

View File

@@ -34,6 +34,5 @@ func newUpstreamResolver(
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
upstreamExchangeClient := &dns.Client{}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
}

View File

@@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
}
timeout := upstreamTimeout
timeout := UpstreamTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = time.Until(deadline)
}
@@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return client.Exchange(r, upstream)
return ExchangeWithFallback(nil, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface

View File

@@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
name: "Should Resolve A Record",
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
timeout: upstreamTimeout,
timeout: UpstreamTimeout,
expectedAnswer: "1.1.1.1",
},
{
@@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
cancelCTX: true,
timeout: upstreamTimeout,
timeout: UpstreamTimeout,
responseShouldBeNil: true,
},
}
@@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
r: new(dns.Msg),
rtt: time.Millisecond,
},
upstreamTimeout: upstreamTimeout,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}

View File

@@ -3,34 +3,53 @@ package dnsfwd
import (
"context"
"errors"
"fmt"
"math"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
listenAddress string
ttl uint32
statusRecorder *peer.Status
dnsServer *dns.Server
mux *dns.ServeMux
mutex sync.RWMutex
fwdEntries []*ForwarderEntry
firewall firewall.Manager
}
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
listenAddress: listenAddress,
ttl: ttl,
firewall: firewall,
statusRecorder: statusRecorder,
}
}
func (f *DNSForwarder) Listen(domains []string) error {
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
@@ -42,23 +61,35 @@ func (f *DNSForwarder) Listen(domains []string) error {
f.dnsServer = dnsServer
f.mux = mux
f.UpdateDomains(domains)
f.UpdateDomains(entries)
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
log.Debugf("Updating domains from %v to %v", f.domains, domains)
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
for _, d := range f.domains {
f.mux.HandleRemove(d)
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
newDomains := filterDomains(domains)
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery)
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQuery)
}
f.domains = newDomains
f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
}
func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -72,48 +103,116 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
question := query.Question[0]
domain := question.Name
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
question.Name, question.Qtype, question.Qclass)
domain := strings.ToLower(question.Name)
resp := query.SetReply(query)
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
ips, err := net.LookupIP(domain)
if err != nil {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
return
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(w, resp, domain, err)
return
}
f.updateInternalState(domain, ips)
f.addIPsToResponse(resp, domain, ips)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) {
var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" {
for _, ip := range ips {
var prefix netip.Prefix
if ip.Is4() {
prefix = netip.PrefixFrom(ip, 32)
} else {
prefix = netip.PrefixFrom(ip, 128)
}
prefixes = append(prefixes, prefix)
f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId)
}
}
if f.firewall != nil {
f.updateFirewall(matchingEntries, prefixes)
}
}
func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) {
var merr *multierror.Error
for _, entry := range matchingEntries {
if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err))
}
}
if merr != nil {
log.Errorf("failed to update firewall sets (%d/%d): %v",
len(merr.Errors),
len(matchingEntries),
nberrors.FormatErrorOrNil(merr))
}
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.To4() == nil {
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
@@ -125,7 +224,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip,
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
@@ -137,21 +236,55 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
}
resp.Answer = append(resp.Answer, respRecord)
}
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
// getMatchingEntries retrieves the resource IDs for a given domain.
// It returns the most specific match and all matching resource IDs.
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {
var selectedResId route.ResID
var bestScore int
var matches []*ForwarderEntry
f.mutex.RLock()
defer f.mutex.RUnlock()
for _, entry := range f.fwdEntries {
var score int
pattern := entry.Domain.PunycodeString()
switch {
case strings.HasPrefix(pattern, "*."):
baseDomain := strings.TrimPrefix(pattern, "*.")
if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) {
score = len(baseDomain)
matches = append(matches, entry)
}
case domain == pattern:
score = math.MaxInt
matches = append(matches, entry)
default:
continue
}
if score > bestScore {
bestScore = score
selectedResId = entry.ResID
}
}
return selectedResId, matches
}
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))
for _, d := range domains {
if d == "" {
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, nbdns.NormalizeZone(d))
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -0,0 +1,103 @@
package dnsfwd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
func Test_getMatchingEntries(t *testing.T) {
testCases := []struct {
name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId
queryDomain string
expectedResId route.ResID
}{
{
name: "Empty map returns empty string",
storedMappings: map[string]route.ResID{},
queryDomain: "example.com",
expectedResId: "",
},
{
name: "Exact match returns stored resId",
storedMappings: map[string]route.ResID{"example.com": "res1"},
queryDomain: "example.com",
expectedResId: "res1",
},
{
name: "Wildcard pattern matches base domain",
storedMappings: map[string]route.ResID{"*.example.com": "res2"},
queryDomain: "example.com",
expectedResId: "res2",
},
{
name: "Wildcard pattern matches subdomain",
storedMappings: map[string]route.ResID{"*.example.com": "res3"},
queryDomain: "foo.example.com",
expectedResId: "res3",
},
{
name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com",
expectedResId: "",
},
{
name: "Non-wildcard pattern does not match subdomain",
storedMappings: map[string]route.ResID{"example.com": "res5"},
queryDomain: "foo.example.com",
expectedResId: "",
},
{
name: "Exact match over overlapping wildcard",
storedMappings: map[string]route.ResID{
"*.example.com": "resWildcard",
"foo.example.com": "resExact",
},
queryDomain: "foo.example.com",
expectedResId: "resExact",
},
{
name: "Overlapping wildcards: Select more specific wildcard",
storedMappings: map[string]route.ResID{
"*.example.com": "resA",
"*.sub.example.com": "resB",
},
queryDomain: "bar.sub.example.com",
expectedResId: "resB",
},
{
name: "Wildcard multi-level subdomain match",
storedMappings: map[string]route.ResID{
"*.example.com": "resMulti",
},
queryDomain: "a.b.example.com",
expectedResId: "resMulti",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fwd := &DNSForwarder{}
var entries []*ForwarderEntry
for domainPattern, resId := range tc.storedMappings {
d, err := domain.FromString(domainPattern)
require.NoError(t, err)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: resId,
})
}
fwd.UpdateDomains(entries)
got, _ := fwd.getMatchingEntries(tc.queryDomain)
assert.Equal(t, got, tc.expectedResId)
})
}
}

View File

@@ -10,6 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
const (
@@ -18,20 +21,29 @@ const (
dnsTTL = 60 //seconds
)
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
type ForwarderEntry struct {
Domain domain.Domain
ResID route.ResID
Set firewall.Set
}
type Manager struct {
firewall firewall.Manager
firewall firewall.Manager
statusRecorder *peer.Status
fwRules []firewall.Rule
dnsForwarder *DNSForwarder
}
func NewManager(fw firewall.Manager) *Manager {
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{
firewall: fw,
firewall: fw,
statusRecorder: statusRecorder,
}
}
func (m *Manager) Start(domains []string) error {
func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil {
return nil
@@ -41,9 +53,9 @@ func (m *Manager) Start(domains []string) error {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(domains); err != nil {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
@@ -52,12 +64,12 @@ func (m *Manager) Start(domains []string) error {
return nil
}
func (m *Manager) UpdateDomains(domains []string) {
func (m *Manager) UpdateDomains(entries []*ForwarderEntry) {
if m.dnsForwarder == nil {
return
}
m.dnsForwarder.UpdateDomains(domains)
m.dnsForwarder.UpdateDomains(entries)
}
func (m *Manager) Stop(ctx context.Context) error {
@@ -78,34 +90,34 @@ func (m *Manager) Stop(ctx context.Context) error {
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) allowDNSFirewall() error {
func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []uint16{ListenPort},
}
if h.firewall == nil {
if m.firewall == nil {
return nil
}
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.fwRules = dnsRules
m.fwRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
for _, rule := range m.fwRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
h.fwRules = nil
m.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
// start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
@@ -527,7 +527,7 @@ func (e *Engine) blockLanAccess() {
if _, err := e.firewall.AddRouteFiltering(
nil,
[]netip.Prefix{v4},
network,
firewallManager.Network{Prefix: network},
firewallManager.ProtocolALL,
nil,
nil,
@@ -952,11 +952,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
// Apply ACLs in the beginning to avoid security leaks
if e.acl != nil {
e.acl.ApplyFiltering(networkMap)
}
if e.firewall != nil {
if localipfw, ok := e.firewall.(localIpUpdater); ok {
if err := localipfw.UpdateLocalIPs(); err != nil {
@@ -965,16 +960,21 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
// DNS forwarder
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
// apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
@@ -1079,21 +1079,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes
}
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
var dnsRoutes []string
for _, protoRoute := range protoRoutes {
if len(protoRoute.Domains) == 0 {
func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderEntry {
var entries []*dnsfwd.ForwarderEntry
for _, route := range routes {
if len(route.Domains) == 0 {
continue
}
if protoRoute.Peer == myPubKey {
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
if route.Peer == myPubKey {
domainSet := firewallManager.NewDomainSet(route.Domains)
for _, d := range route.Domains {
entries = append(entries, &dnsfwd.ForwarderEntry{
Domain: d,
Set: domainSet,
ResID: route.GetResourceID(),
})
}
}
}
return dnsRoutes
return entries
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
@@ -1223,36 +1226,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
PreSharedKey: e.config.PreSharedKey,
}
if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
rk := []byte(wgConfig.RemoteKey)
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
wgConfig.PreSharedKey = &key
}
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(),
Key: pubKey,
LocalKey: e.config.WgPrivateKey.PublicKey().String(),
Timeout: timeout,
WgConfig: wgConfig,
LocalWgPort: e.config.WgPort,
RosenpassConfig: peer.RosenpassConfig{
PubKey: e.getRosenpassPubKey(),
Addr: e.getRosenpassAddr(),
PermissiveMode: e.config.RosenpassPermissive,
},
ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
@@ -1760,7 +1746,10 @@ func (e *Engine) GetWgAddr() net.IP {
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
func (e *Engine) updateDNSForwarder(
enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry,
) {
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1771,18 +1760,18 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
return
}
if len(domains) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if len(fwdEntries) > 0 {
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(domains); err != nil {
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
log.Infof("started domain router service with %d entries", len(fwdEntries))
} else {
log.Infof("update domain router service for domains: %v", domains)
e.dnsForwardMgr.UpdateDomains(domains)
e.dnsForwardMgr.UpdateDomains(fwdEntries)
}
} else if e.dnsForwardMgr != nil {
log.Infof("disable domain router service")

View File

@@ -49,6 +49,7 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -1400,15 +1401,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
Relay: &server.Relay{
config := &types.Config{
Stuns: []*types.Host{},
TURNConfig: &types.TURNConfig{},
Relay: &types.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &server.Host{
Signal: &types.Host{
Proto: "http",
URI: "localhost:10000",
},
@@ -1446,7 +1447,9 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
permissionsManager := permissions.NewManager(store)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
if err != nil {
return nil, "", err
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
nbnet "github.com/netbirdio/netbird/util/net"
)
const defaultChannelSize = 100
@@ -176,7 +177,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
srcIP := flow.TupleOrig.IP.SourceAddress
dstIP := flow.TupleOrig.IP.DestinationAddress
if !c.relevantFlow(srcIP, dstIP) {
if !c.relevantFlow(flow.Mark, srcIP, dstIP) {
return
}
@@ -193,7 +194,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
}
flowID := c.getFlowID(flow.ID)
direction := c.inferDirection(srcIP, dstIP)
direction := c.inferDirection(flow.Mark, srcIP, dstIP)
eventType := nftypes.TypeStart
eventStr := "New"
@@ -224,15 +225,14 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
}
// relevantFlow checks if the flow is related to the specified interface
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
// TODO: filter traffic by interface
wgnet := c.iface.Address().Network
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
return false
func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
if nbnet.IsDataPlaneMark(mark) {
return true
}
return true
// fallback if mark rules are not in place
wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice())
}
// mapRxPackets maps packet counts to RX based on flow direction
@@ -282,7 +282,15 @@ func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
return uuid.NewSHA1(c.instanceID, buf[:])
}
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes.Direction {
switch mark {
case nbnet.DataPlaneMarkIn:
return nftypes.Ingress
case nbnet.DataPlaneMarkOut:
return nftypes.Egress
}
// fallback if marks are not set
wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
@@ -298,8 +306,6 @@ func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
case wgnetwork.Contains(dst):
// resource network -> netbird network
return nftypes.Egress
// TODO: handle site2site traffic
}
return nftypes.DirectionUnknown

View File

@@ -19,11 +19,9 @@ import (
type rcvChan chan *types.EventFields
type Logger struct {
mux sync.Mutex
ctx context.Context
cancel context.CancelFunc
enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc
cancel context.CancelFunc
statusRecorder *peer.Status
wgIfaceIPNet net.IPNet
dnsCollection atomic.Bool
@@ -31,12 +29,9 @@ type Logger struct {
Store types.Store
}
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
ctx, cancel := context.WithCancel(ctx)
return &Logger{
ctx: ctx,
cancel: cancel,
statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet,
Store: store.NewMemoryStore(),
@@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
}
l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel
ctx, cancel := context.WithCancel(context.Background())
l.cancel = cancel
l.mux.Unlock()
c := make(rcvChan, 100)
@@ -91,25 +86,25 @@ func (l *Logger) startReceiver() {
Timestamp: time.Now().UTC(),
}
var isExitNode bool
if event.Direction == types.Ingress {
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
event.SourceResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
}
} else if event.Direction == types.Egress {
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
event.DestResourceID, isExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
}
var isSrcExitNode bool
var isDestExitNode bool
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
}
if l.shouldStore(eventFields, isExitNode) {
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
}
if l.shouldStore(eventFields, isSrcExitNode || isDestExitNode) {
l.Store.StoreEvent(&event)
}
}
}
}
func (l *Logger) Disable() {
func (l *Logger) Close() {
l.stop()
l.Store.Close()
}
@@ -121,9 +116,9 @@ func (l *Logger) stop() {
l.enabled.Store(false)
l.mux.Lock()
if l.cancelReceiver != nil {
l.cancelReceiver()
l.cancelReceiver = nil
if l.cancel != nil {
l.cancel()
l.cancel = nil
}
l.rcvChan.Store(nil)
l.mux.Unlock()
@@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
l.exitNodeCollection.Store(exitNodeCollection)
}
func (l *Logger) Close() {
l.stop()
l.cancel()
}
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {

View File

@@ -1,7 +1,6 @@
package logger_test
import (
"context"
"net"
"testing"
"time"
@@ -13,7 +12,7 @@ import (
)
func TestStore(t *testing.T) {
logger := logger.New(context.Background(), nil, net.IPNet{})
logger := logger.New(nil, net.IPNet{})
logger.Enable()
event := types.EventFields{
@@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
}
// test disable
logger.Disable()
logger.Close()
wait()
logger.StoreEvent(event)
wait()

View File

@@ -27,18 +27,18 @@ type Manager struct {
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
ctx context.Context
receiverClient *client.GRPCClient
publicKey []byte
cancel context.CancelFunc
}
// NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var ipNet net.IPNet
if iface != nil {
ipNet = *iface.Address().Network
}
flowLogger := logger.New(ctx, statusRecorder, ipNet)
flowLogger := logger.New(statusRecorder, ipNet)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
return &Manager{
logger: flowLogger,
conntrack: ct,
ctx: ctx,
publicKey: publicKey,
}
}
@@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %v", err)
}
if err := m.resetClient(); err != nil {
return fmt.Errorf("reset client: %w", err)
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
}
m.logger.Enable()
@@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
return nil
}
func (m *Manager) resetClient() error {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %v", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
if m.cancel != nil {
m.cancel()
}
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
return nil
}
// disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error {
if m.cancel != nil {
m.cancel()
}
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
m.logger.Close()
if m.receiverClient != nil {
return m.receiverClient.Close()
}
return nil
}
@@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
changed := previous != nil && update.Enabled != previous.Enabled
if update.Enabled {
log.Infof("netflow manager enabled; starting netflow manager")
if changed {
log.Infof("netflow manager enabled; starting netflow manager")
}
return m.enableFlow(previous)
}
log.Infof("netflow manager disabled; stopping netflow manager")
err := m.disableFlow()
if err != nil {
log.Errorf("failed to disable netflow manager: %v", err)
if changed {
log.Infof("netflow manager disabled; stopping netflow manager")
}
return err
return m.disableFlow()
}
// Close cleans up all resources
@@ -151,17 +172,9 @@ func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if m.conntrack != nil {
m.conntrack.Close()
if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err)
}
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %v", err)
}
}
m.logger.Close()
}
// GetLogger returns the flow logger
@@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
return m.logger
}
func (m *Manager) startSender() {
func (m *Manager) startSender(ctx context.Context) {
ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
case <-ctx.Done():
return
case <-ticker.C:
events := m.logger.GetEvents()
@@ -190,8 +203,8 @@ func (m *Manager) startSender() {
}
}
func (m *Manager) receiveACKs(client *client.GRPCClient) {
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
id, err := uuid.FromBytes(ack.EventId)
if err != nil {
log.Warnf("failed to convert ack event id to uuid: %v", err)

View File

@@ -0,0 +1,200 @@
package netflow
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
)
type mockIFaceMapper struct {
address wgaddr.Address
isUserspaceBind bool
}
func (m *mockIFaceMapper) Name() string {
return "wt0"
}
func (m *mockIFaceMapper) Address() wgaddr.Address {
return m.address
}
func (m *mockIFaceMapper) IsUserspaceBind() bool {
return m.isUserspaceBind
}
func TestManager_Update(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
statusRecorder := peer.NewRecorder("")
manager := NewManager(mockIFace, publicKey, statusRecorder)
tests := []struct {
name string
config *types.FlowConfig
}{
{
name: "nil config",
config: nil,
},
{
name: "disabled config",
config: &types.FlowConfig{
Enabled: false,
},
},
{
name: "enabled config with minimal valid settings",
config: &types.FlowConfig{
Enabled: true,
URL: "https://example.com",
TokenPayload: "test-payload",
TokenSignature: "test-signature",
Interval: 30 * time.Second,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := manager.Update(tc.config)
assert.NoError(t, err)
if tc.config == nil {
return
}
require.NotNil(t, manager.flowConfig)
if tc.config.Enabled {
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
}
if tc.config.URL != "" {
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
}
if tc.config.TokenPayload != "" {
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
}
})
}
}
func TestManager_Update_TokenPreservation(t *testing.T) {
mockIFace := &mockIFaceMapper{
address: wgaddr.Address{
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.1"),
Mask: net.CIDRMask(24, 32),
},
},
isUserspaceBind: true,
}
publicKey := []byte("test-public-key")
manager := NewManager(mockIFace, publicKey, nil)
// First update with tokens
initialConfig := &types.FlowConfig{
Enabled: false,
TokenPayload: "initial-payload",
TokenSignature: "initial-signature",
}
err := manager.Update(initialConfig)
require.NoError(t, err)
// Second update without tokens should preserve them
updatedConfig := &types.FlowConfig{
Enabled: false,
URL: "https://example.com",
}
err = manager.Update(updatedConfig)
require.NoError(t, err)
// Verify tokens were preserved
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
}
func TestManager_NeedsNewClient(t *testing.T) {
manager := &Manager{}
tests := []struct {
name string
previous *types.FlowConfig
current *types.FlowConfig
expected bool
}{
{
name: "nil previous config",
previous: nil,
current: &types.FlowConfig{},
expected: true,
},
{
name: "previous disabled",
previous: &types.FlowConfig{Enabled: false},
current: &types.FlowConfig{Enabled: true},
expected: true,
},
{
name: "different URL",
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
expected: true,
},
{
name: "different TokenPayload",
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
expected: true,
},
{
name: "different TokenSignature",
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
expected: true,
},
{
name: "same config",
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
expected: false,
},
{
name: "only interval changed",
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
manager.flowConfig = tc.current
result := manager.needsNewClient(tc.previous)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -10,6 +10,8 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const ZoneID = 0x1BD0
type Protocol uint8
const (
@@ -120,9 +122,6 @@ type FlowLogger interface {
Close()
// Enable enables the flow logger receiver
Enable()
// Disable disables the flow logger receiver
Disable()
// UpdateConfig updates the flow manager configuration
UpdateConfig(dnsCollection, exitNodeCollection bool)
}

View File

@@ -60,6 +60,15 @@ type WgConfig struct {
PreSharedKey *wgtypes.Key
}
type RosenpassConfig struct {
// RosenpassPubKey is this peer's Rosenpass public key
PubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
Addr string
PermissiveMode bool
}
// ConnConfig is a peer Connection configuration
type ConnConfig struct {
// Key is a public key of a remote peer
@@ -73,10 +82,7 @@ type ConnConfig struct {
LocalWgPort int
// RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
RosenpassAddr string
RosenpassConfig RosenpassConfig
// ICEConfig ICE protocol configuration
ICEConfig icemaker.Config
@@ -103,11 +109,14 @@ type Conn struct {
workerICE *WorkerICE
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
@@ -211,6 +220,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
// Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close() {
conn.mu.Lock()
defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock()
conn.log.Infof("close peer connection")
@@ -252,6 +262,7 @@ func (conn *Conn) Close() {
}
conn.setStatusToDisconnected()
conn.log.Infof("peer connection has been closed")
}
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
@@ -362,6 +373,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause()
@@ -371,7 +383,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
wgProxy.Work()
}
if err = conn.configureWGEndpoint(ep); err != nil {
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
@@ -404,10 +416,15 @@ func (conn *Conn) onICEStateDisconnected() {
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.currentConnPriority = connPriorityRelay
} else {
conn.log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", connPriorityNone.String())
@@ -469,16 +486,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
}
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.setRelayedProxy(wgProxy)
@@ -502,6 +525,7 @@ func (conn *Conn) onRelayDisconnected() {
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.currentConnPriority = connPriorityNone
}
if conn.wgProxyRelay != nil {
@@ -541,13 +565,14 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
presharedKey := conn.presharedKey(remoteRPKey)
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
conn.config.WgConfig.PreSharedKey,
presharedKey,
)
}
@@ -768,6 +793,44 @@ func (conn *Conn) AllowedIP() netip.Addr {
return conn.config.WgConfig.AllowedIps[0].Addr()
}
func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
if conn.config.RosenpassConfig.PubKey == nil {
return conn.config.WgConfig.PreSharedKey
}
if remoteRosenpassKey == nil && conn.config.RosenpassConfig.PermissiveMode {
return conn.config.WgConfig.PreSharedKey
}
determKey, err := conn.rosenpassDetermKey()
if err != nil {
conn.log.Errorf("failed to generate Rosenpass initial key: %v", err)
return conn.config.WgConfig.PreSharedKey
}
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"context"
"fmt"
"os"
"sync"
"testing"
@@ -161,3 +162,145 @@ func TestConn_Status(t *testing.T) {
})
}
}
func TestConn_presharedKey(t *testing.T) {
conn1 := Conn{
config: ConnConfig{
Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
RosenpassConfig: RosenpassConfig{},
},
}
conn2 := Conn{
config: ConnConfig{
Key: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
LocalKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
RosenpassConfig: RosenpassConfig{},
},
}
tests := []struct {
conn1Permissive bool
conn1RosenpassEnabled bool
conn2Permissive bool
conn2RosenpassEnabled bool
conn1ExpectedInitialKey bool
conn2ExpectedInitialKey bool
}{
{
conn1Permissive: false,
conn1RosenpassEnabled: false,
conn2Permissive: false,
conn2RosenpassEnabled: false,
conn1ExpectedInitialKey: false,
conn2ExpectedInitialKey: false,
},
{
conn1Permissive: false,
conn1RosenpassEnabled: true,
conn2Permissive: false,
conn2RosenpassEnabled: true,
conn1ExpectedInitialKey: true,
conn2ExpectedInitialKey: true,
},
{
conn1Permissive: false,
conn1RosenpassEnabled: true,
conn2Permissive: false,
conn2RosenpassEnabled: false,
conn1ExpectedInitialKey: true,
conn2ExpectedInitialKey: false,
},
{
conn1Permissive: false,
conn1RosenpassEnabled: false,
conn2Permissive: false,
conn2RosenpassEnabled: true,
conn1ExpectedInitialKey: false,
conn2ExpectedInitialKey: true,
},
{
conn1Permissive: true,
conn1RosenpassEnabled: true,
conn2Permissive: false,
conn2RosenpassEnabled: false,
conn1ExpectedInitialKey: false,
conn2ExpectedInitialKey: false,
},
{
conn1Permissive: false,
conn1RosenpassEnabled: false,
conn2Permissive: true,
conn2RosenpassEnabled: true,
conn1ExpectedInitialKey: false,
conn2ExpectedInitialKey: false,
},
{
conn1Permissive: true,
conn1RosenpassEnabled: true,
conn2Permissive: true,
conn2RosenpassEnabled: true,
conn1ExpectedInitialKey: true,
conn2ExpectedInitialKey: true,
},
{
conn1Permissive: false,
conn1RosenpassEnabled: false,
conn2Permissive: false,
conn2RosenpassEnabled: true,
conn1ExpectedInitialKey: false,
conn2ExpectedInitialKey: true,
},
{
conn1Permissive: false,
conn1RosenpassEnabled: true,
conn2Permissive: true,
conn2RosenpassEnabled: true,
conn1ExpectedInitialKey: true,
conn2ExpectedInitialKey: true,
},
}
conn1.config.RosenpassConfig.PermissiveMode = true
for i, test := range tests {
tcase := i + 1
t.Run(fmt.Sprintf("Rosenpass test case %d", tcase), func(t *testing.T) {
conn1.config.RosenpassConfig = RosenpassConfig{}
conn2.config.RosenpassConfig = RosenpassConfig{}
if test.conn1RosenpassEnabled {
conn1.config.RosenpassConfig.PubKey = []byte("dummykey")
}
conn1.config.RosenpassConfig.PermissiveMode = test.conn1Permissive
if test.conn2RosenpassEnabled {
conn2.config.RosenpassConfig.PubKey = []byte("dummykey")
}
conn2.config.RosenpassConfig.PermissiveMode = test.conn2Permissive
conn1PresharedKey := conn1.presharedKey(conn2.config.RosenpassConfig.PubKey)
conn2PresharedKey := conn2.presharedKey(conn1.config.RosenpassConfig.PubKey)
if test.conn1ExpectedInitialKey {
if conn1PresharedKey == nil {
t.Errorf("Case %d: Expected conn1 to have a non-nil key, but got nil", tcase)
}
} else {
if conn1PresharedKey != nil {
t.Errorf("Case %d: Expected conn1 to have a nil key, but got %v", tcase, conn1PresharedKey)
}
}
// Assert conn2's key expectation
if test.conn2ExpectedInitialKey {
if conn2PresharedKey == nil {
t.Errorf("Case %d: Expected conn2 to have a non-nil key, but got nil", tcase)
}
} else {
if conn2PresharedKey != nil {
t.Errorf("Case %d: Expected conn2 to have a nil key, but got %v", tcase, conn2PresharedKey)
}
}
})
}
}

View File

@@ -154,8 +154,8 @@ func (h *Handshaker) sendOffer() error {
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
@@ -174,8 +174,8 @@ func (h *Handshaker) sendAnswer() error {
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassPubKey,
RosenpassAddr: h.config.RosenpassAddr,
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {

View File

@@ -4,6 +4,7 @@ import (
"time"
"github.com/pion/ice/v3"
"github.com/pion/logging"
"github.com/pion/randutil"
log "github.com/sirupsen/logrus"
@@ -35,6 +36,10 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
log.Errorf("failed to create pion's stdnet: %s", err)
}
fac := logging.NewDefaultLoggerFactory()
//fac.Writer = log.StandardLogger().Writer()
agentConfig := &ice.AgentConfig{
MulticastDNSMode: ice.MulticastDNSModeDisabled,
NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
@@ -51,6 +56,7 @@ func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candida
RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
LocalUfrag: ufrag,
LocalPwd: pwd,
LoggerFactory: fac,
}
if config.DisableIPv6Discovery {

View File

@@ -2,40 +2,94 @@ package peer
import (
"net/netip"
"sort"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/route"
)
// routeEntry holds the route prefix and the corresponding resource ID.
type routeEntry struct {
prefix netip.Prefix
resourceID route.ResID
}
type routeIDLookup struct {
localMap sync.Map
remoteMap sync.Map
localRoutes []routeEntry
localLock sync.RWMutex
remoteRoutes []routeEntry
remoteLock sync.RWMutex
resolvedIPs sync.Map
}
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
_, exists := r.localMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in local map", resourceID)
func (r *routeIDLookup) AddLocalRouteID(resourceID route.ResID, route netip.Prefix) {
r.localLock.Lock()
defer r.localLock.Unlock()
// update the resource id if the route already exists.
for i, entry := range r.localRoutes {
if entry.prefix == route {
r.localRoutes[i].resourceID = resourceID
log.Tracef("resourceID for route %v updated to %s in local routes", route, resourceID)
return
}
}
// append and sort descending by prefix bits (more specific first)
r.localRoutes = append(r.localRoutes, routeEntry{prefix: route, resourceID: resourceID})
sort.Slice(r.localRoutes, func(i, j int) bool {
return r.localRoutes[i].prefix.Bits() > r.localRoutes[j].prefix.Bits()
})
}
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
r.localMap.Delete(route)
}
r.localLock.Lock()
defer r.localLock.Unlock()
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in remote map", resourceID)
for i, entry := range r.localRoutes {
if entry.prefix == route {
r.localRoutes = append(r.localRoutes[:i], r.localRoutes[i+1:]...)
return
}
}
}
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
r.remoteMap.Delete(route)
func (r *routeIDLookup) AddRemoteRouteID(resourceID route.ResID, route netip.Prefix) {
r.remoteLock.Lock()
defer r.remoteLock.Unlock()
for i, entry := range r.remoteRoutes {
if entry.prefix == route {
r.remoteRoutes[i].resourceID = resourceID
log.Tracef("resourceID for route %v updated to %s in remote routes", route, resourceID)
return
}
}
// append and sort descending by prefix bits.
r.remoteRoutes = append(r.remoteRoutes, routeEntry{prefix: route, resourceID: resourceID})
sort.Slice(r.remoteRoutes, func(i, j int) bool {
return r.remoteRoutes[i].prefix.Bits() > r.remoteRoutes[j].prefix.Bits()
})
}
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
r.remoteLock.Lock()
defer r.remoteLock.Unlock()
for i, entry := range r.remoteRoutes {
if entry.prefix == route {
r.remoteRoutes = append(r.remoteRoutes[:i], r.remoteRoutes[i+1:]...)
return
}
}
}
func (r *routeIDLookup) AddResolvedIP(resourceID route.ResID, route netip.Prefix) {
r.resolvedIPs.Store(route.Addr(), resourceID)
}
@@ -44,37 +98,35 @@ func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
}
// Lookup returns the resource ID for the given IP address
// and a bool indicating if the IP is an exit node
func (r *routeIDLookup) Lookup(ip netip.Addr) (string, bool) {
var isExitNode bool
resId, ok := r.resolvedIPs.Load(ip)
if ok {
return resId.(string), false
// and a bool indicating if the IP is an exit node.
func (r *routeIDLookup) Lookup(ip netip.Addr) (route.ResID, bool) {
if res, ok := r.resolvedIPs.Load(ip); ok {
return res.(route.ResID), false
}
var resourceID string
r.localMap.Range(func(key, value interface{}) bool {
pref := key.(netip.Prefix)
if pref.Contains(ip) {
resourceID = value.(string)
isExitNode = pref.Bits() == 0
return false
var resourceID route.ResID
var isExitNode bool
r.localLock.RLock()
for _, entry := range r.localRoutes {
if entry.prefix.Contains(ip) {
resourceID = entry.resourceID
isExitNode = entry.prefix.Bits() == 0
break
}
return true
})
}
r.localLock.RUnlock()
if resourceID == "" {
r.remoteMap.Range(func(key, value interface{}) bool {
pref := key.(netip.Prefix)
if pref.Contains(ip) {
resourceID = value.(string)
isExitNode = pref.Bits() == 0
return false
r.remoteLock.RLock()
for _, entry := range r.remoteRoutes {
if entry.prefix.Contains(ip) {
resourceID = entry.resourceID
isExitNode = entry.prefix.Bits() == 0
break
}
return true
})
}
r.remoteLock.RUnlock()
}
return resourceID, isExitNode

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10
@@ -313,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -581,14 +582,13 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
}
// AddLocalPeerStateRoute adds a route to the local peer state
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
func (d *Status) AddLocalPeerStateRoute(route string, resourceId route.ResID) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
if err == nil {
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
if d.localPeer.Routes == nil {
@@ -596,8 +596,6 @@ func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
}
d.localPeer.Routes[route] = struct{}{}
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
// RemoveLocalPeerStateRoute removes a route from the local peer state
@@ -606,14 +604,30 @@ func (d *Status) RemoveLocalPeerStateRoute(route string) {
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
if err == nil {
d.routeIDLookup.RemoveLocalRouteID(pref)
}
delete(d.localPeer.Routes, route)
}
d.routeIDLookup.RemoveLocalRouteID(pref)
// AddResolvedIPLookupEntry adds a resolved IP lookup entry
func (d *Status) AddResolvedIPLookupEntry(prefix netip.Prefix, resourceId route.ResID) {
d.mux.Lock()
defer d.mux.Unlock()
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
}
// RemoveResolvedIPLookupEntry removes a resolved IP lookup entry
func (d *Status) RemoveResolvedIPLookupEntry(route string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err == nil {
d.routeIDLookup.RemoveResolvedIP(pref)
}
}
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
@@ -707,7 +721,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId route.ResID) {
d.mux.Lock()
defer d.mux.Unlock()

View File

@@ -32,7 +32,6 @@ type WGWatcher struct {
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
waitGroup sync.WaitGroup
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -48,24 +47,24 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
defer w.ctxLock.Unlock()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctxLock.Unlock()
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err)
}
w.waitGroup.Add(1)
go w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
}
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
@@ -81,13 +80,11 @@ func (w *WGWatcher) DisableWgWatcher() {
w.ctxCancel()
w.ctxCancel = nil
w.waitGroup.Wait()
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
defer w.waitGroup.Done()
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()

Some files were not shown because too many files have changed in this diff Show More