Compare commits

..

67 Commits

Author SHA1 Message Date
Pedro Costa
7b02e9c3a8 [management] base manager 2025-03-05 09:52:18 +00:00
hakansa
eee90fbbbf [client] UI Refactor Icon Paths (#3420)
[client] UI Refactor Icon Paths (#3420)
2025-03-05 09:47:29 +00:00
Viktor Liu
85aea0a030 [client] Close userspace firewall properly (#3426) 2025-03-05 09:47:27 +00:00
robertgro
e41acdb9ac [client] Add Netbird GitHub link to the client ui about sub menu (#3372) 2025-03-05 09:46:21 +00:00
Philippe Vaucher
54dc27abec [client Fix env var typo (#3415) 2025-03-05 09:46:21 +00:00
Bethuel Mmbaga
fdd7dc67c0 [management] Handle transaction error on peer deletion (#3387)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-05 09:46:21 +00:00
Viktor Liu
2d4fcaf186 Fix proto numbering (#3436) 2025-03-04 16:57:25 +01:00
Viktor Liu
acf172b52c Add kernel conntrack counters (#3434) 2025-03-04 16:46:03 +01:00
Viktor Liu
8c81a823fa Add flow ACL IDs (#3421) 2025-03-04 16:43:07 +01:00
Maycon Santos
619c549547 sync port forwarding 2025-03-04 16:29:59 +01:00
Maycon Santos
9a713a0987 Merge branch 'feature/port-forwarding' into feature/flow
# Conflicts:
#	go.mod
#	go.sum
2025-03-04 16:28:57 +01:00
Pascal Fischer
c4945cd565 add cleanup scheduler + metrics 2025-03-04 16:21:52 +01:00
Viktor Liu
1e10c17ecb Fix tcp state (#3431) 2025-03-04 11:19:54 +01:00
Viktor Liu
96d5190436 Add icmp type and code to forwarder flow event (#3413) 2025-02-28 21:04:07 +01:00
Viktor Liu
d19c26df06 Fix log direction (#3412) 2025-02-28 21:03:40 +01:00
Viktor Liu
36e36414d9 Fix forwarder log displaying (#3411) 2025-02-28 20:53:01 +01:00
bcmmbaga
7e69589e05 Update management-integrations
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:49:56 +00:00
bcmmbaga
aa613ab79a Update golang.org/x/crypto/ssh
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:27:46 +00:00
Viktor Liu
6ead0ff95e Fix log format 2025-02-28 20:24:23 +01:00
Viktor Liu
0db65a8984 Add routed packet drop flow (#3410) 2025-02-28 20:04:59 +01:00
Pascal Fischer
c138807e95 remove log message 2025-02-28 19:54:50 +01:00
Viktor Liu
637c0c8949 Add icmp type and code (#3409) 2025-02-28 19:16:42 +01:00
Viktor Liu
c72e13d8e6 Add conntrack flows (#3406) 2025-02-28 19:16:29 +01:00
Maycon Santos
f6d7bccfa0 Add flow client with sender/receiver (#3405)
add an initial version of receiver client and flow manager receiver and sender
2025-02-28 17:16:18 +00:00
bcmmbaga
e3ed01cafb go mod tidy
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 17:10:44 +00:00
Viktor Liu
fa748a7ec2 Add userspace flow implementation (#3393) 2025-02-28 11:08:35 +01:00
Maycon Santos
cccc615783 update flow proto package generated code 2025-02-28 03:09:09 +00:00
Maycon Santos
2021463ca0 update flow proto package name 2025-02-28 02:51:57 +00:00
Maycon Santos
f48cfd52e9 fix logger stop (#3403)
* fix logger stop

* use context to stop receiver

* update test
2025-02-28 00:28:17 +00:00
Pascal Fischer
6838f53f40 add getPeerByIp store method 2025-02-27 19:01:05 +01:00
Maycon Santos
8276236dfa Add netflow manager (#3398)
* Add netflow manager

* fix linter issues
2025-02-27 12:05:20 +00:00
Viktor Liu
994b923d56 Move proto and rename port and icmp info (#3399) 2025-02-27 12:52:33 +01:00
Viktor Liu
59e2432231 Add event proto fields (#3397) 2025-02-27 12:29:50 +01:00
Pascal Fischer
eee0d123e4 [management] add flow settings and credentials (#3389) 2025-02-27 12:17:07 +01:00
Viktor Liu
e943203ae2 Add event fields (#3390)
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-02-26 12:06:06 +01:00
Pedro Costa
6a775217cf rename flow proto messages 2025-02-25 16:29:54 +00:00
Maycon Santos
175674749f Add memory flow store (#3386) 2025-02-25 15:23:43 +00:00
Pascal Fischer
1e534cecf6 [management] Add flow proto (#3384) 2025-02-25 13:03:27 +01:00
Pedro Costa
aa3aa8c6a8 [management] flow proto 2025-02-25 11:22:54 +00:00
Pascal Fischer
fbdfe45c25 fix merge conflicts on management 2025-02-25 11:57:25 +01:00
Viktor Liu
81ee172db8 Fix route conflict 2025-02-25 11:44:21 +01:00
Viktor Liu
f8fd65a65f Merge branch 'main' into feature/port-forwarding 2025-02-25 11:37:52 +01:00
Bethuel Mmbaga
62b978c050 [management] Add support for tcp/udp allocations (#3381)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 10:11:50 +00:00
Bethuel Mmbaga
4ebf1410c6 [management] Add support to allocate same port for public and internal (#3347)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-21 11:16:24 +03:00
Viktor Liu
630edf2480 Remove unused var 2025-02-20 13:24:37 +01:00
Viktor Liu
ea469d28d7 Merge branch 'main' into feature/port-forwarding 2025-02-20 13:24:05 +01:00
Pascal Fischer
597f1d47b8 fix management test suite 2025-02-20 13:08:18 +01:00
Viktor Liu
fcc96417f9 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:45:30 +01:00
Viktor Liu
8755211a60 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:39:06 +01:00
Pascal Fischer
e6d4653b08 [management] add cloud tag to get ingress ports api spec (#3300)
* fix tag for get endpoint

* update labels
2025-02-12 16:11:54 +01:00
Zoltan Papp
eb69f2de78 Fix nil pointer exception when load empty list and try to cast it (#3282) 2025-02-06 10:28:42 +01:00
Viktor Liu
206420c085 [client] Fix grouping of peer ACLs with different port ranges (#3289) 2025-02-06 10:28:42 +01:00
Christian Stewart
88a864c195 [relay] Use new upstream for nhooyr.io/websocket package (#3287)
The nhooyr.io/websocket package was renamed to github.com/coder/websocket when
the project was transferred to "coder" as the new maintainer.

Use the new import path and update go.mod and go.sum accordingly.

Signed-off-by: Christian Stewart <christian@aperture.us>
2025-02-06 10:28:42 +01:00
Pascal Fischer
a789e9e6d8 [management] fix duplication detection (#3286) 2025-02-05 21:42:09 +01:00
Viktor Liu
9930913e4e Merge branch 'main' into feature/port-forwarding 2025-02-05 18:55:59 +01:00
Viktor Liu
48675f579f Merge branch 'main' into feature/port-forwarding 2025-02-05 17:44:01 +01:00
Pascal Fischer
afec455f86 [management] copy port info (#3283) 2025-02-05 17:30:42 +01:00
Pascal Fischer
035c5d9f23 [management merge only unique entries on network map merge (#3277) 2025-02-05 16:50:45 +01:00
Viktor Liu
b2a5b29fb2 Merge branch 'main' into feature/port-forwarding 2025-02-05 10:15:37 +01:00
Bethuel Mmbaga
9ec61206c2 [management] Add support for filtering peers by name and IP (#3279)
* add peers ip and name filters

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get peers filter

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix get account peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Extend GetAccountPeers store to support filtering by name and IP

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix get peers references

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-05 00:33:15 +03:00
Zoltan Papp
1b011a2d85 [client] Manage the IP forwarding sysctl setting in global way (#3270)
Add new package ipfwdstate that implements reference counting for IP forwarding
state management. This allows multiple usage to safely request IP forwarding
without interfering with each other.
2025-02-03 12:27:18 +01:00
Pascal Fischer
a85ea1ddb0 [manager] ingress ports manager support (#3268)
* add peers manager

* Extend peers manager to support retrieving all peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add network map calc

* move integrations interface

* update management-integrations

* merge main and fix

* go mod tidy

* [management] port forwarding add peer manager fix network map (#3264)

* [management] fix testing tools (#3265)

* Fix net.IPv4 conversion to []byte

* update test to check ipv4

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2025-02-03 09:37:37 +01:00
Zoltán Papp
829e40d2aa Fix ingress manager unnecessary creation 2025-02-01 10:58:47 +01:00
Pascal Fischer
6344e34880 [management] renamed ingress port endpoints (#3263) 2025-02-01 00:40:33 +01:00
Pascal Fischer
a76ca8c565 Merge branch 'main' into feature/port-forwarding 2025-01-29 22:28:10 +01:00
Zoltan Papp
26693e4ea8 Feature/port forwarding client ingress (#3242)
Client-side forward handling

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-01-29 16:04:33 +01:00
Pascal Fischer
f6a71f4193 [management] add openapi specs and generate types for port forwarding proxy (#3236) 2025-01-27 17:47:40 +01:00
416 changed files with 10917 additions and 27676 deletions

View File

@@ -1,27 +0,0 @@
# More info around this file at https://www.git-town.com/configuration-file
[branches]
main = "main"
perennials = []
perennial-regex = ""
[create]
new-branch-type = "feature"
push-new-branches = false
[hosting]
dev-remote = "origin"
# platform = ""
# origin-hostname = ""
[ship]
delete-tracking-branch = false
strategy = "squash-merge"
[sync]
feature-strategy = "merge"
perennial-strategy = "rebase"
prototype-strategy = "merge"
push-hook = true
tags = true
upstream = false

View File

@@ -31,22 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**Is any other VPN software installed?**
**NetBird status -dA output:**
If yes, which one?
If applicable, add the `netbird status -dA' command output.
**Debug output**
**Do you face any (non-mobile) client issues?**
To help us resolve the problem, please attach the following debug output
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots**
@@ -55,10 +47,3 @@ If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -2,10 +2,6 @@
## Issue ticket number and link
## Stack
<!-- branch-stack -->
### Checklist
- [ ] Is it a bug fix
- [ ] Is a typo/documentation fix

View File

@@ -1,21 +0,0 @@
name: Git Town
on:
pull_request:
branches:
- '**'
jobs:
git-town:
name: Display the branch stack
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
with:
skip-single-stacks: true

View File

@@ -22,20 +22,14 @@ jobs:
with:
usesh: true
copyback: false
release: "14.2"
release: "14.1"
prepare: |
pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
pkg install -y go pkgconf xorg
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...

View File

@@ -146,65 +146,6 @@ jobs:
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
${{ steps.go-env.outputs.cache_dir }}
${{ steps.go-env.outputs.modcache_dir }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Run tests in container
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
docker run --rm \
--cap-add=NET_ADMIN \
--privileged \
-v $PWD:/app \
-w /app \
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
-e CGO_ENABLED=1 \
-e CI=true \
-e DOCKER_CI=true \
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
name: "Relay / Unit"
needs: [build-cache]
@@ -238,6 +179,13 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -284,6 +232,13 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -303,7 +258,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
@@ -331,6 +286,13 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -352,7 +314,6 @@ jobs:
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
@@ -364,8 +325,8 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -392,6 +353,13 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -412,11 +380,10 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
-timeout 20m ./...
api_benchmark:
name: "Management / Benchmark (API)"
@@ -425,37 +392,10 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -480,6 +420,13 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -500,13 +447,11 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_integration_test:
@@ -516,7 +461,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
@@ -544,6 +489,13 @@ jobs:
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
@@ -553,8 +505,89 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin
- name: Run Shared Sock tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@@ -87,25 +87,25 @@ jobs:
with:
name: release
path: dist/
retention-days: 7
retention-days: 3
- name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
retention-days: 3
- name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
retention-days: 3
- name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
retention-days: 7
retention-days: 3
release_ui:
runs-on: ubuntu-latest

View File

@@ -178,7 +178,6 @@ jobs:
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
- name: Install modules
run: go mod tidy

View File

@@ -96,20 +96,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
binary: netbird-upload
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -423,52 +409,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -535,17 +475,7 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -1,64 +1,148 @@
## Contributor License Agreement
# Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
of the terms and conditions outlined below. The Contributor further represents that they are authorized to
complete this process as described herein.
We are incredibly thankful for the contributions we receive from the community.
We require our external contributors to sign a Contributor License Agreement ("CLA") in
order to ensure that our projects remain licensed under Free and Open Source licenses such
as BSD-3 while allowing NetBird to build a sustainable business.
NetBird is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables NetBird to safely commercialize our products
while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project.
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
# Human-friendly summary
This is a human-readable summary of (and not a substitute for) the full agreement (below).
This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
in commercial products.
</li>
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
license to use that patent including within commercial products. You also agree that you
have permission to grant this license.
</li>
<li>No Warranty or Support Obligations.
By making a contribution, you are not obligating yourself to provide support for the
contribution, and you are not taking on any warranty obligations or providing any
assurances about how it will perform.
</li>
The CLA does not change the terms of the standard open source license used by our software
such as BSD-3 or MIT.
You are still free to use our projects within your own projects or businesses, republish
modified source, and more.
Please reference the appropriate license for the project you're contributing to to learn
more.
# Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
under the project's respective open source license, such as BSD-3.
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
# Signing the CLA
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
to sign the CLA if you haven't already.
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
require you to sign again.
# Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
You reserve all right, title, and interest in and to Your Contributions.
1. Definitions.
```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
```
```
"Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
or documentation of, any of the products owned or managed by NetBird (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of,
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution."
```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works.
## 1 Preamble
In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird
must have a contributor license agreement on file to be signed by each Contributor, containing the license
terms below. This license serves as protection for both the Contributor as well as NetBird and its software users;
it does not change Contributors rights to use his/her own Contributions for any other purpose.
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
## 2 Definitions
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work.
4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to
intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
with NetBird.
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution".
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
others). You represent that Your Contribution submissions include complete details of any third-party license or
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
and which are associated with any part of Your Contributions.
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
## 3 Licenses
3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees.
6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributors Contribution(s) alone or by combination of Contributors Contribution(s) with the Work to which such Contribution(s) was Submitted.
3.3 NetBird hereby accepts such licenses.
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
## 4 Contributors Representations
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributors employer has IP Rights to Contributors Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
4.2 Contributor represents that any Contribution is his/her original creation.
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
## 5 Information obligation
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
## 6 Submission of Third-Party works
Should Contributor wish to submit work that is not Contributors original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
## 7 No Consideration
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
## 8 Final Provisions
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
@@ -29,13 +29,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
<br/>
</strong>
<br>
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
</a>
</p>
@@ -57,16 +57,16 @@
### Key features
| Connectivity | Management | Security | Automation| Platforms |
|----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
||||| <ul><li>- \[x] Docker</ui></li> |
| Connectivity | Management | Security | Automation | Platforms |
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud

View File

@@ -1,7 +0,0 @@
# For details on buf.gen.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-gen-yaml/
version: v2
plugins:
- remote: buf.build/protocolbuffers/go:v1.35.1
out: .
- remote: buf.build/grpc/go:v1.5.1
out: .

View File

@@ -1,10 +0,0 @@
# For details on buf.yaml configuration, visit https://buf.build/docs/configuration/v2/buf-yaml
version: v2
modules:
- path: proto
lint:
use:
- BASIC
breaking:
use:
- FILE

View File

@@ -1,6 +1,5 @@
FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird
COPY netbird /usr/local/bin/netbird

View File

@@ -26,7 +26,7 @@ type Anonymizer struct {
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 198.51.100.0, 100::
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}

View File

@@ -11,12 +11,9 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -87,27 +84,16 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}()
client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag,
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
cmd.Println(resp.GetPath())
return nil
}
@@ -222,19 +208,23 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
}
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -249,15 +239,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
cmd.Println(resp.GetPath())
return nil
}
@@ -344,34 +326,3 @@ func formatDuration(d time.Duration) string {
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
if connectClient != nil {
networkMap, err = connectClient.GetLatestNetworkMap()
if err != nil {
log.Warnf("Failed to get latest network map: %v", err)
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: config,
StatusRecorder: recorder,
NetworkMap: networkMap,
LogFile: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
log.Errorf("Failed to generate debug bundle: %v", err)
return
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}

View File

@@ -1,39 +0,0 @@
//go:build unix
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
usr1Ch := make(chan os.Signal, 1)
signal.Notify(usr1Ch, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-usr1Ch:
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
go generateDebugBundle(config, recorder, connectClient, logFilePath)
}
}
}()
}

View File

@@ -1,126 +0,0 @@
package cmd
import (
"context"
"errors"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
waitTimeout = 5 * time.Second
)
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
// Example usage with PowerShell:
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
// $evt.Set()
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
env := os.Getenv(envListenEvent)
if env == "" {
return
}
listenEvent, err := strconv.ParseBool(env)
if err != nil {
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
return
}
if !listenEvent {
return
}
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
return
}
// TODO: restrict access by ACL
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
} else {
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
return
}
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
}
func waitForEvent(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
eventHandle windows.Handle,
) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on debug event. Triggering debug bundle generation.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
}
go generateDebugBundle(config, recorder, connectClient, logFilePath)
case uint32(windows.WAIT_TIMEOUT):
default:
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
case <-ctx.Done():
return
}
}
}
}

View File

@@ -19,10 +19,6 @@ import (
"github.com/netbirdio/netbird/util"
)
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
var loginCmd = &cobra.Command{
Use: "login",
Short: "login to the Netbird Management Service (first run)",
@@ -55,9 +51,6 @@ var loginCmd = &cobra.Command{
return err
}
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
@@ -134,7 +127,7 @@ var loginCmd = &cobra.Command{
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
@@ -205,7 +198,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
@@ -219,27 +212,19 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil
}
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
}
if noBrowser {
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
} else {
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
}
cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg)
cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
)
const (
@@ -40,8 +39,6 @@ const (
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
)
var (
@@ -78,8 +75,6 @@ var (
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
rootCmd = &cobra.Command{
Use: "netbird",
@@ -185,9 +180,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
}
// SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -16,17 +16,12 @@ import (
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
)
func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()
@@ -120,7 +115,6 @@ var runCmd = &cobra.Command{
ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil {

View File

@@ -6,17 +6,14 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
@@ -34,7 +31,7 @@ import (
func startTestingServices(t *testing.T) string {
t.Helper()
config := &types.Config{}
config := &mgmt.Config{}
_, err := util.ReadJson("../testdata/management.json", config)
if err != nil {
t.Fatal(err)
@@ -69,7 +66,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
@@ -92,24 +89,14 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -32,16 +32,12 @@ const (
const (
dnsLabelsFlag = "extra-dns-labels"
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
)
var (
foregroundMode bool
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
upCmd = &cobra.Command{
Use: "up",
@@ -69,9 +65,6 @@ func init() {
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -219,8 +212,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
}
@@ -358,7 +349,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {

View File

@@ -134,11 +134,10 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
run := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
run <- err
}
}()
@@ -148,9 +147,13 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
case err := <-run:
if err != nil {
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
}
c.connect = client

View File

@@ -4,13 +4,12 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() device.WGAddress
IsUserspaceBind() bool
SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}
@@ -113,16 +113,17 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
destination netip.Prefix,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -242,14 +243,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),

View File

@@ -38,12 +38,10 @@ const (
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
@@ -57,18 +55,18 @@ type ruleInfo struct {
}
type routeFilteringRuleParams struct {
Source firewall.Network
Destination firewall.Network
Sources []netip.Prefix
Destination netip.Prefix
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
}
type routeRules map[string][]string
// the ipset library currently does not support comments, so we use the name only (string)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct {
@@ -117,10 +115,6 @@ func (r *router) init(stateManager *statemanager.Manager) error {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.updateState()
return nil
@@ -129,7 +123,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -140,28 +134,27 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil
}
var source firewall.Network
var setName string
if len(sources) > 1 {
source.Set = firewall.NewPrefixSet(sources)
} else if len(sources) > 0 {
source.Prefix = sources[0]
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}
params := routeFilteringRuleParams{
Source: source,
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule, err := r.genRouteRuleSpec(params, sources)
if err != nil {
return nil, fmt.Errorf("generate route rule spec: %w", err)
}
rule := genRouteFilteringRuleSpec(params)
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
@@ -184,13 +177,17 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
@@ -201,26 +198,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil
}
func (r *router) decrementSetCounter(rule []string) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule []string) []string {
var sets []string
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
return rule[i+3]
}
}
return sets
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
@@ -241,8 +225,6 @@ func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
@@ -282,14 +264,12 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
log.Errorf("%v", err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -327,10 +307,8 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
}
return nil
@@ -370,16 +348,12 @@ func (r *router) Reset() error {
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
r.updateState()
return nberrors.FormatErrorOrNil(merr)
@@ -449,57 +423,6 @@ func (r *router) createContainers() error {
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
@@ -541,7 +464,7 @@ func (r *router) insertEstablishedRule(chain string) error {
}
func (r *router) addJumpRules() error {
// Jump to nat chain
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err)
@@ -615,26 +538,12 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
)
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
// TODO: rollback ipset counter
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}
@@ -652,10 +561,6 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("marking rule %s not found", ruleKey)
}
@@ -821,21 +726,17 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
if err != nil {
return nil, fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", params.Destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
@@ -845,47 +746,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
rule = append(rule, "-j", actionToStr(params.Action))
return rule, nil
}
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
return rule
}
func applyPort(flag string, port *firewall.Port) []string {

View File

@@ -46,9 +46,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule
// 7. static return masquerade rule
// 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
require.Len(t, manager.rules, 7, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -60,8 +58,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
pair := firewall.RouterPair{
ID: "abc",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
@@ -332,7 +330,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
@@ -347,29 +345,23 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content
params := routeFilteringRuleParams{
Source: source,
Destination: firewall.Network{Prefix: tt.destination},
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
expectedRule, err = r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec with set")
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
@@ -384,62 +376,3 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
})
}
}
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
expected []string
}{
{
name: "Basic rule with two sets",
rule: []string{
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
},
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
},
{
name: "No sets",
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
expected: []string{},
},
{
name: "Multiple sets with different positions",
rule: []string{
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
},
expected: []string{"set1", "set-abc123"},
},
{
name: "Boundary case - sequence appears at end",
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
expected: []string{"final-set"},
},
{
name: "Incomplete pattern - missing set name",
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
expected: []string{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
return
}
for i, set := range result {
if set != tc.expected[i] {
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
}
}
})
}
}

View File

@@ -4,20 +4,21 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}

View File

@@ -1,10 +1,13 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
@@ -40,18 +43,6 @@ const (
// Action is the action to be taken on a rule
type Action int
// String returns the string representation of the action
func (a Action) String() string {
switch a {
case ActionAccept:
return "accept"
case ActionDrop:
return "drop"
default:
return "unknown"
}
}
const (
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
@@ -59,33 +50,6 @@ const (
ActionDrop
)
// Network is a rule destination, either a set or a prefix
type Network struct {
Set Set
Prefix netip.Prefix
}
// String returns the string representation of the destination
func (d Network) String() string {
if d.Prefix.IsValid() {
return d.Prefix.String()
}
if d.IsSet() {
return d.Set.HashedName()
}
return "<invalid network>"
}
// IsSet returns true if the destination is a set
func (d Network) IsSet() bool {
return d.Set != Set{}
}
// IsPrefix returns true if the destination is a valid prefix
func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
@@ -119,9 +83,10 @@ type Manager interface {
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
destination netip.Prefix,
proto Protocol,
sPort, dPort *Port,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
@@ -154,9 +119,6 @@ type Manager interface {
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
}
func GenKey(format string, pair RouterPair) string {
@@ -191,6 +153,22 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {

View File

@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.NewPrefixSet(prefixes2)
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.NewPrefixSet(prefixes)
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.NewPrefixSet([]netip.Prefix{})
result2 := manager.NewPrefixSet([]netip.Prefix{})
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.NewPrefixSet(prefixes2)
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)

View File

@@ -1,13 +1,15 @@
package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct {
ID route.ID
Source Network
Destination Network
Source netip.Prefix
Destination netip.Prefix
Masquerade bool
Inverse bool
}

View File

@@ -1,74 +0,0 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
type Set struct {
hash [4]byte
comment string
}
// String returns the string representation of the set: hashed name and comment
func (h Set) String() string {
if h.comment == "" {
return h.HashedName()
}
return h.HashedName() + ": " + h.comment
}
// HashedName returns the string representation of the hash
func (h Set) HashedName() string {
return fmt.Sprintf(
"nb-%s",
hex.EncodeToString(h.hash[:]),
)
}
// Comment returns the comment of the set
func (h Set) Comment() string {
return h.comment
}
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()
for _, src := range prefixes {
bytes, err := src.MarshalBinary()
if err != nil {
log.Warnf("failed to marshal prefix %s: %v", src, err)
}
hash.Write(bytes)
}
var set Set
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}
// NewDomainSet generates a unique name for an ipset based on the given domains.
func NewDomainSet(domains domain.List) Set {
slices.Sort(domains)
hash := sha256.New()
for _, d := range domains {
hash.Write([]byte(d.PunycodeString()))
}
set := Set{
comment: domains.SafeString(),
}
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}

View File

@@ -25,10 +25,9 @@ const (
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
@@ -463,15 +462,13 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
// Chain is created by route manager
// TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
})
m.addFwmarkToForward(chainFwFilter)

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}
@@ -135,16 +135,17 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
destination netip.Prefix,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -241,7 +242,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Close closes the firewall manager
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -358,14 +359,6 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -289,7 +289,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
@@ -298,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)

View File

@@ -10,6 +10,7 @@ import (
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@@ -43,14 +44,9 @@ const (
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
type router struct {
conn *nftables.Conn
workTable *nftables.Table
@@ -58,7 +54,7 @@ type router struct {
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
@@ -104,10 +100,6 @@ func (r *router) init(workTable *nftables.Table) error {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
return nil
}
@@ -167,7 +159,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("unable to list tables: %v", err)
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
@@ -204,21 +196,15 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePostrouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
@@ -234,83 +220,7 @@ func (r *router) createContainers() error {
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
@@ -320,7 +230,7 @@ func (r *router) setupDataPlaneMark() error {
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -335,29 +245,23 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
var source firewall.Network
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
source.Prefix = sources[0]
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default:
// If there are multiple sources, use a set
source.Set = firewall.NewPrefixSet(sources)
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
}
sourceExp, err := r.applyNetwork(source, sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
@@ -401,27 +305,39 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
}
return getIpSetExprs(ref, isSource)
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -440,54 +356,42 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
prefixes := firewall.MergeIPRanges(input.prefixes)
sources = firewall.MergeIPRanges(sources)
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
set := &nftables.Set{
Name: setName,
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
for _, prefix := range sources {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
@@ -503,7 +407,18 @@ func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
return elements
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
}
// calculateLastIP determines the last IP in a given prefix.
@@ -527,8 +442,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b
}
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
@@ -537,27 +452,13 @@ func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
return nil
}
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule *nftables.Rule) []string {
var sets []string
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
return lookup.SetName
}
}
return sets
return ""
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@@ -599,8 +500,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
}
return nil
@@ -608,15 +508,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
op := expr.CmpOpEq
if pair.Inverse {
@@ -624,6 +517,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
@@ -634,9 +547,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Data: ifname(r.wgIface.Name()),
},
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
@@ -666,11 +576,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
}
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Chain: r.chains[chainNamePrerouting],
Exprs: exprs,
UserData: []byte(ruleKey),
})
@@ -751,15 +659,8 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
exprs := []expr.Any{
&expr.Counter{},
@@ -768,8 +669,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@@ -782,7 +682,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: exprs,
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
@@ -797,13 +697,11 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
}
return nil
@@ -1014,14 +912,12 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -1029,10 +925,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
@@ -1040,19 +936,16 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
log.Debugf("prerouting rule %s not found", ruleKey)
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
}
return nil
@@ -1064,7 +957,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf(" unable to list rules: %v", err)
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
@@ -1338,54 +1231,13 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
if err != nil {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
ones := prefix.Bits()
@@ -1472,48 +1324,3 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs
}
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order
// nolint:gocritic
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNameManglePrerouting {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()

View File

@@ -3,20 +3,21 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}

View File

@@ -15,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: false,
},
},
@@ -24,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},
@@ -40,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},

View File

@@ -4,36 +4,39 @@ package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Close cleans up the firewall manager by removing all rules and closing trackers
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {

View File

@@ -3,13 +3,13 @@ package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -21,28 +21,31 @@ const (
firewallRuleName = "Netbird"
)
// Close cleans up the firewall manager by removing all rules and closing trackers
// Close closes the firewall manager
func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {

View File

@@ -3,14 +3,14 @@ package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
Address() iface.WGAddress
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -2,6 +2,7 @@ package conntrack
import (
"fmt"
"net"
"net/netip"
"sync/atomic"
"time"
@@ -13,15 +14,13 @@ import (
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64
}
// these small methods will be inlined by the compiler
@@ -31,17 +30,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
@@ -64,3 +52,16 @@ type ConnKey struct {
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -1,7 +1,8 @@
package conntrack
import (
"net/netip"
"context"
"net"
"testing"
"github.com/sirupsen/logrus"
@@ -11,7 +12,7 @@ import (
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
@@ -20,22 +21,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})
@@ -45,22 +46,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})

View File

@@ -1,8 +1,8 @@
package conntrack
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
@@ -23,13 +23,14 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
ID uint16
SrcIP netip.Addr
DstIP netip.Addr
Sequence uint16
ID uint16
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence)
}
// ICMPConnTrack represents an ICMP connection state
@@ -45,8 +46,8 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
flowLogger nftypes.FlowLogger
}
@@ -56,27 +57,21 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout = DefaultICMPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
}
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) {
key := makeICMPKey(srcIP, dstIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -84,7 +79,6 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
@@ -93,21 +87,22 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) {
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) {
// TODO: icmp doesn't need to extend the timeout
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists {
return
}
@@ -117,7 +112,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
t.sendStartEvent(direction, key, typ, code)
return
}
@@ -125,34 +120,29 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -163,19 +153,16 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -189,58 +176,56 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
t.connections = nil
t.mutex.Unlock()
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
})
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
}
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"net/netip"
"net"
"testing"
)
@@ -10,12 +10,12 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0)
}
})
@@ -23,17 +23,17 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
}
})
}

View File

@@ -3,8 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"context"
"net/netip"
"net"
"sync"
"sync/atomic"
"time"
@@ -23,11 +22,11 @@ const (
)
const (
TCPFin uint8 = 0x01
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPAck uint8 = 0x10
TCPUrg uint8 = 0x20
)
@@ -41,7 +40,7 @@ const (
)
// TCPState represents the state of a TCP connection
type TCPState int32
type TCPState int
func (s TCPState) String() string {
switch s {
@@ -89,25 +88,20 @@ const (
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
state atomic.Int32
tombstone atomic.Bool
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
}
// GetState safely retrieves the current state
func (t *TCPConnTrack) GetState() TCPState {
return TCPState(t.state.Load())
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
}
// SetState safely updates the current state
func (t *TCPConnTrack) SetState(state TCPState) {
t.state.Store(int32(state))
}
// CompareAndSwapState atomically changes the state from old to new if current == old
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
return t.state.CompareAndSwap(int32(old), int32(newState))
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// IsTombstone safely checks if the connection is marked for deletion
@@ -126,51 +120,43 @@ type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
done chan struct{}
timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
waitTimeout := TimeWaitTimeout
if timeout == 0 {
timeout = DefaultTCPTimeout
} else {
waitTimeout = timeout / 45
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
t.updateState(key, conn, flags, direction, size)
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.UpdateLastSeen()
conn.Unlock()
return key, true
}
@@ -178,181 +164,172 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 {
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
conn.established.Store(false)
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction, size)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsTombstone() {
if !exists {
return false
}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
if conn.IsTombstone() {
return true
}
return false
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
return true
}
t.updateState(key, conn, flags, nftypes.Ingress, size)
return true
conn.Lock()
t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
return isEstablished || isValidState
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
currentState := conn.GetState()
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
state := conn.State
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
return
}
}()
var newState TCPState
switch currentState {
switch state {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
conn.State = TCPStateSynSent
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if packetDir != conn.Direction {
newState = TCPStateEstablished
if isOutbound {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
} else {
// Simultaneous open
newState = TCPStateSynReceived
conn.State = TCPStateSynReceived
}
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
if packetDir == conn.Direction {
newState = TCPStateEstablished
}
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if packetDir == conn.Direction {
newState = TCPStateFinWait1
if isOutbound {
conn.State = TCPStateFinWait1
} else {
newState = TCPStateCloseWait
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
}
case TCPStateFinWait1:
if packetDir != conn.Direction {
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
newState = TCPStateClosing
case flags&TCPFin != 0:
newState = TCPStateClosing
case flags&TCPAck != 0:
newState = TCPStateFinWait2
}
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
newState = TCPStateTimeWait
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
case TCPStateClosing:
if flags&TCPAck != 0 {
newState = TCPStateTimeWait
conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
newState = TCPStateLastAck
conn.State = TCPStateLastAck
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
newState = TCPStateClosed
}
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.State = TCPStateClosed
conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, key, conn)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
}
}
@@ -362,22 +339,18 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
}
return true
}
switch state {
case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent:
// TODO: support simultaneous open
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0
case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0
@@ -394,20 +367,20 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
}
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -425,25 +398,24 @@ func (t *TCPTracker) cleanup() {
}
var timeout time.Duration
currentState := conn.GetState()
switch currentState {
case TCPStateTimeWait:
timeout = t.waitTimeout
case TCPStateEstablished:
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.IsEstablished():
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
// Return IPs to pool
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key)
// event already handled by state change
if currentState != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
@@ -451,7 +423,8 @@ func (t *TCPTracker) cleanup() {
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs
t.mutex.Lock()
@@ -473,20 +446,15 @@ func isValidFlagCombination(flags uint8) bool {
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
})
}

View File

@@ -1,83 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
"time"
)
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -1,11 +1,10 @@
package conntrack
import (
"net/netip"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -13,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -59,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@@ -77,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@@ -100,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
},
{
@@ -123,8 +122,11 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{
@@ -136,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@@ -163,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -179,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: true,
desc: "Should accept RST for established connection",
@@ -193,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: false,
desc: "Should reject RST without connection",
@@ -206,455 +208,101 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST()
// Verify connection state is as expected
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.GetState())
require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
}
})
}
}
func TestTCPRetransmissions(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test SYN retransmission
t.Run("SYN Retransmission", func(t *testing.T) {
// Initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Retransmit SYN (should not affect the state machine)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Verify we're still in SYN-SENT state
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateSynSent, conn.GetState())
// Complete the handshake
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Verify we're in ESTABLISHED state
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test ACK retransmission in established state
t.Run("ACK Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Retransmit ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test FIN retransmission
t.Run("FIN Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Retransmit FIN (should not change state)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
})
}
func TestTCPDataTransfer(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Data Transfer", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Receive ACK for data
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
require.True(t, valid)
// Receive data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
require.True(t, valid)
// Send ACK for received data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
})
}
func TestTCPHalfClosedConnections(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test half-closed connection: local end closes, remote end continues sending data
t.Run("Local Close, Remote Data", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Remote end can still send data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
require.True(t, valid)
// We can still ACK their data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Receive FIN from remote end
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain TIME-WAIT (waiting for possible retransmissions)
require.Equal(t, TCPStateTimeWait, conn.GetState())
})
// Test half-closed connection: remote end closes, local end continues sending data
t.Run("Remote Close, Local Data", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Receive FIN from remote
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// We can still send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Remote can still ACK our data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Send our FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Receive final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
})
}
func TestTCPAbnormalSequences(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test handling of unsolicited RST in various states
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
// Send SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive unsolicited RST (without proper ACK)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
// Receive RST with proper ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.Equal(t, TCPStateClosed, conn.GetState())
require.True(t, conn.IsTombstone())
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Connection Timeout", func(t *testing.T) {
// Establish a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Wait for the connection to timeout
time.Sleep(2 * shortTimeout)
// Force cleanup
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after timeout")
})
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Complete the connection close to enter TIME_WAIT
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// TIME_WAIT should have its own timeout value (usually 2*MSL)
// For the test, we're using a short timeout
time.Sleep(2 * shortTimeout)
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
})
}
func TestSynFlood(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
basePort := uint16(10000)
dstPort := uint16(80)
// Create a large number of SYN packets to simulate a SYN flood
for i := uint16(0); i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
}
// Check that we're tracking all connections
require.Equal(t, 1000, len(tracker.connections))
// Now simulate SYN timeout
var oldConns int
tracker.mutex.Lock()
for _, conn := range tracker.connections {
if conn.GetState() == TCPStateSynSent {
// Make the connection appear old
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
oldConns++
}
}
tracker.mutex.Unlock()
require.Equal(t, 1000, oldConns)
// Run cleanup
tracker.cleanup()
// Check that stale connections were cleaned up
require.Equal(t, 0, len(tracker.connections))
}
func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := netip.MustParseAddr("100.64.0.1")
serverIP := netip.MustParseAddr("100.64.0.2")
clientPort := uint16(12345)
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{
SrcIP: clientIP,
DstIP: serverIP,
SrcPort: clientPort,
DstPort: serverPort,
}
tracker.mutex.RLock()
conn := tracker.connections[key]
tracker.mutex.RUnlock()
require.NotNil(t, conn)
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
// 2. Server sends SYN-ACK response
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
// Server sends data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -1,8 +1,7 @@
package conntrack
import (
"context"
"net/netip"
"net"
"sync"
"time"
@@ -22,8 +21,6 @@ const (
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
}
// UDPTracker manages UDP connection states
@@ -32,8 +29,8 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
flowLogger nftypes.FlowLogger
}
@@ -43,41 +40,34 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout = DefaultUDPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -85,7 +75,6 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
@@ -93,41 +82,35 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
FlowId: uuid.New(),
Direction: direction,
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: srcPort,
DestPort: dstPort,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
t.sendEvent(nftypes.TypeStart, key, conn)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
@@ -138,20 +121,17 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -165,16 +145,16 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("Removed UDP connection %s (timeout)", key)
t.sendEvent(nftypes.TypeEnd, key, conn)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
t.connections = nil
@@ -182,16 +162,11 @@ func (t *UDPTracker) Close() {
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
return conn, exists
}
@@ -201,20 +176,15 @@ func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
SourceIP: key.SrcIP,
DestIP: key.DstIP,
SourcePort: key.SrcPort,
DestPort: key.DstPort,
})
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"context"
"net"
"net/netip"
"testing"
"time"
@@ -35,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.tickerCancel)
assert.NotNil(t, tracker.done)
})
}
}
@@ -49,15 +49,10 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
// Verify connection was tracked
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
@@ -71,18 +66,18 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
sleep time.Duration
@@ -99,7 +94,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
},
{
name: "invalid source IP",
srcIP: netip.MustParseAddr("192.168.1.4"),
srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
@@ -109,7 +104,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
@@ -149,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got)
})
}
@@ -160,45 +155,42 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
tickerCancel: tickerCancel,
done: make(chan struct{}),
logger: logger,
flowLogger: flowLogger,
}
// Start cleanup routine
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
// Add some connections
connections := []struct {
srcIP netip.Addr
dstIP netip.Addr
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
}{
{
srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"),
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"),
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
}
// Verify initial connections
@@ -223,12 +215,12 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
}
})
@@ -236,17 +228,17 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
}
})
}

View File

@@ -4,9 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"runtime"
"sync"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
@@ -19,7 +17,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -32,10 +29,8 @@ const (
)
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
@@ -172,35 +167,3 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
}
return addr.AsSlice()
}
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
key := buildKey(srcIP, dstIP, srcPort, dstPort)
f.ruleIdMap.LoadOrStore(key, ruleID)
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
return value.([]byte), true
}
return nil, false
}
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return
}
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
}
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
return conntrack.ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -15,17 +15,14 @@ import (
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
flowID := uuid.New()
// Extract ICMP header to get type and code
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
@@ -34,55 +31,72 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
rxBytes := pkt.Size()
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
return 0
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
f.logger.Error("Failed to read ICMP response: %v", err)
}
return 0
return true
}
ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -101,54 +115,28 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
f.logger.Error("Failed to inject ICMP response: %v", err)
return 0
return true
}
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
return true
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
var rxPackets, txPackets uint64
if rxBytes > 0 {
rxPackets = 1
}
if txBytes > 0 {
txPackets = 1
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
})
}

View File

@@ -6,10 +6,8 @@ import (
"io"
"net"
"net/netip"
"sync"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -24,14 +22,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
}
}()
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
@@ -60,105 +51,64 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
errChan := make(chan error, 2)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
}
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
fields := nftypes.EventFields{
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
Protocol: 6,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}

View File

@@ -89,6 +89,21 @@ func (f *udpForwarder) Stop() {
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
@@ -125,6 +140,8 @@ func (f *udpForwarder) cleanup() {
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
}
}
}
@@ -148,19 +165,13 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
}
}()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message
return
}
@@ -173,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
@@ -199,109 +211,72 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id)
}()
var wg sync.WaitGroup
wg.Add(2)
errChan := make(chan error, 2)
var txBytes, rxBytes int64
var outboundErr, inboundErr error
// outbound->inbound: copy from pConn.conn to pConn.outConn
go func() {
defer wg.Done()
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
// inbound->outbound: copy from pConn.outConn to pConn.conn
go func() {
defer wg.Done()
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
}
var rxPackets, txPackets uint64
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
rxPackets = udpStats.PacketsSent.Value()
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
Protocol: 17, // UDP protocol number
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
func (c *udpPacketConn) updateLastSeen() {
@@ -313,20 +288,18 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
return time.Since(lastSeen)
}
// copy reads from src and writes to dst.
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
var totalBytes int64 = 0
for {
if ctx.Err() != nil {
return totalBytes, ctx.Err()
return ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return totalBytes, fmt.Errorf("set read deadline: %w", err)
return fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
@@ -334,15 +307,14 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
if isTimeout(err) {
continue
}
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
return fmt.Errorf("read from %s: %w", direction, err)
}
nWritten, err := dst.Write(buffer[:n])
_, err = dst.Write(buffer[:n])
if err != nil {
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
return fmt.Errorf("write to %s: %w", direction, err)
}
totalBytes += int64(nWritten)
c.updateLastSeen()
}
}

View File

@@ -3,7 +3,6 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
@@ -14,13 +13,8 @@ import (
type localIPManager struct {
mu sync.RWMutex
// fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [256]*ipv4LowBitmap
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
}
func newLocalIPManager() *localIPManager {
@@ -32,59 +26,39 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
if ipv4 == nil {
return
}
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
if ipv4 := ip.To4(); ipv4 != nil {
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
ipStr := ipv4.String()
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -102,7 +76,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue
}
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
@@ -115,14 +89,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}
}()
var newIPv4Bitmap [256]*ipv4LowBitmap
var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := 0; i < 8192; i++ {
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
}
if iface != nil {
@@ -148,13 +122,13 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.checkBitmapBit(ip.AsSlice())
if ipv4 := ip.To4(); ipv4 != nil {
return m.checkBitmapBit(ipv4)
}
return false
}

View File

@@ -2,103 +2,90 @@ package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr wgaddr.Address
testIP netip.Addr
setupAddr iface.WGAddress
testIP net.IP
expected bool
}{
{
name: "Localhost range",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
testIP: net.ParseIP("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
testIP: net.ParseIP("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
testIP: net.ParseIP("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
testIP: net.ParseIP("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.33"),
testIP: net.ParseIP("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
testIP: net.ParseIP("fe80::1"),
expected: false,
},
}
@@ -108,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
AddressFunc: func() iface.WGAddress {
return tt.setupAddr
},
}
@@ -187,7 +174,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
result := manager.IsLocalIP(net.ParseIP(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}
@@ -204,8 +191,10 @@ func BenchmarkIPChecks(b *testing.B) {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap
bitmapManager := newLocalIPManager()
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
@@ -258,7 +247,7 @@ func BenchmarkWGPosition(b *testing.B) {
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := newLocalIPManager()
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@@ -267,7 +256,7 @@ func BenchmarkWGPosition(b *testing.B) {
})
b.Run("WG_Last", func(b *testing.B) {
bm := newLocalIPManager()
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"net"
"net/netip"
"github.com/google/gopacket"
@@ -12,7 +13,7 @@ import (
type PeerRule struct {
id string
mgmtId []byte
ip netip.Addr
ip net.IP
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
@@ -29,15 +30,14 @@ func (r *PeerRule) ID() string {
}
type RouteRule struct {
id string
mgmtId []byte
sources []netip.Prefix
dstSet firewall.Set
destinations []netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
id string
mgmtId []byte
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// ID returns the rule id

View File

@@ -2,7 +2,7 @@ package uspfilter
import (
"fmt"
"net/netip"
"net"
"time"
"github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
}
type PacketTrace struct {
SourceIP netip.Addr
DestinationIP netip.Addr
SourceIP net.IP
DestinationIP net.IP
Protocol string
SourcePort uint16
DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
}
type PacketBuilder struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcIP net.IP
DstIP net.IP
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
SrcIP: p.SrcIP,
DstIP: p.DstIP,
}
}
@@ -260,30 +260,28 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.localipmanager.IsLocalIP(dstIP) {
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter.Load() {
if m.nativeRouter {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
@@ -311,46 +309,39 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
strRuleId := "implicit"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true
}
trace.AddResult(StagePeerACL, msg, true)
trace.AddResult(StagePeerACL, msg, !blocked)
// Handle netstack mode
if m.netstack {
switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
}
// In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled.Load() {
if !m.routingEnabled {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
@@ -366,14 +357,14 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {
strId = "<no id>"
strId = "implicit"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
@@ -382,7 +373,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
if allowed && m.forwarder != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
@@ -401,7 +392,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
dropped := m.processOutgoingHooks(packetData)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -1,440 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
"192.168.17.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

View File

@@ -10,7 +10,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -49,10 +48,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
type RouteRules []*RouteRule
type RouteRules []RouteRule
func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b *RouteRule) int {
slices.SortStableFunc(r, func(a, b RouteRule) int {
// Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1
@@ -67,9 +66,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager
type Manager struct {
// outgoingRules is used for hooks only
outgoingRules map[netip.Addr]RuleSet
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet
incomingRules map[string]RuleSet
routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool
@@ -81,9 +80,9 @@ type Manager struct {
// indicates whether server routes are disabled
disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves
routingEnabled atomic.Bool
routingEnabled bool
// indicates whether we leave forwarding and filtering to the native firewall
nativeRouter atomic.Bool
nativeRouter bool
// indicates whether we track outbound connections
stateful bool
// indicates whether wireguards runs in netstack mode
@@ -96,11 +95,9 @@ type Manager struct {
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
forwarder atomic.Pointer[forwarder.Forwarder]
forwarder *forwarder.Forwarder
logger *nblog.Logger
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
}
// decoder for packages
@@ -171,18 +168,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
}
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
@@ -203,35 +200,41 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}
}
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
}
if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
return m, nil
}
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
return fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering(
if _, err := m.AddRouteFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewall.Network{Prefix: wgPrefix},
wgPrefix,
firewall.ProtocolALL,
nil,
nil,
firewall.ActionDrop,
)
if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err)
); err != nil {
return fmt.Errorf("block wg nte : %w", err)
}
// TODO: Block networks that we're a client of
return rule, nil
return nil
}
func (m *Manager) determineRouting() error {
@@ -252,20 +255,20 @@ func (m *Manager) determineRouting() error {
switch {
case disableUspRouting:
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
m.routingEnabled = false
m.nativeRouter = false
log.Info("userspace routing is disabled")
case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
m.routingEnabled = true
m.nativeRouter = true
log.Info("server routes are disabled")
case forceUserspaceRouter:
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing is forced")
@@ -273,19 +276,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
m.routingEnabled = true
m.nativeRouter = true
log.Info("native routing is enabled")
default:
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.routingEnabled = true
m.nativeRouter = false
log.Info("userspace routing enabled by default")
}
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
if m.routingEnabled && !m.nativeRouter {
return m.initForwarder()
}
@@ -294,24 +297,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
if m.forwarder.Load() != nil {
if m.forwarder != nil {
return nil
}
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice()
if intf == nil {
m.routingEnabled.Store(false)
m.routingEnabled = false
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil {
m.routingEnabled.Store(false)
m.routingEnabled = false
return fmt.Errorf("create forwarder: %w", err)
}
m.forwarder.Store(forwarder)
m.forwarder = forwarder
log.Debug("forwarder initialized")
@@ -327,7 +330,7 @@ func (m *Manager) IsServerRouteSupported() bool {
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
}
@@ -338,7 +341,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair)
}
return nil
@@ -357,23 +360,17 @@ func (m *Manager) AddPeerFiltering(
action firewall.Action,
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: i,
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if i.Is4() {
if ipNormalized := ip.To4(); ipNormalized != nil {
r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@@ -398,10 +395,10 @@ func (m *Manager) AddPeerFiltering(
}
m.mutex.Lock()
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(RuleSet)
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip][r.id] = r
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
@@ -409,65 +406,48 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
destination netip.Prefix,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String()
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
id: ruleID,
mgmtId: id,
sources: sources,
destination: destination,
proto: proto,
srcPort: sPort,
dstPort: dPort,
action: action,
}
m.routeRules = append(m.routeRules, &rule)
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort()
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
}
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID
})
if idx < 0 {
@@ -488,10 +468,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip], r.id)
delete(m.incomingRules[r.ip.String()], r.id)
return nil
}
@@ -523,60 +503,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.UpdateSet(set, prefixes)
}
m.mutex.Lock()
defer m.mutex.Unlock()
var matches []*RouteRule
for _, rule := range m.routeRules {
if rule.dstSet == set {
matches = append(matches, rule)
}
}
if len(matches) == 0 {
return fmt.Errorf("no route rule found for set: %s", set)
}
destinations := matches[0].destinations
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
if cmp != 0 {
return cmp
}
return a.Bits() - b.Bits()
})
destinations = slices.Compact(destinations)
for _, rule := range matches {
rule.destinations = destinations
}
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
return nil
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size)
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData, size)
func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData)
}
// UpdateLocalIPs updates the list of local IPs
@@ -584,7 +518,10 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface)
}
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
@@ -597,34 +534,31 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
}
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
if srcIP == nil {
return false
}
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
// Track all protocols if stateful mode is enabled
if m.stateful {
m.trackOutbound(d, srcIP, dstIP)
}
if m.stateful {
m.trackOutbound(d, srcIP, dstIP, size)
// Process UDP hooks even if stateful mode is disabled
if d.decoded[1] == layers.LayerTypeUDP {
return m.checkUDPHooks(d, dstIP, packetData)
}
return false
}
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
return d.ip4.SrcIP, d.ip4.DstIP
case layers.LayerTypeIPv6:
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
return d.ip6.SrcIP, d.ip6.DstIP
default:
return netip.Addr{}, netip.Addr{}
return nil, nil
}
}
@@ -651,153 +585,127 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode)
}
}
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort))
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode)
}
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.udpHook(packetData)
}
}
}
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
// dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, size int) bool {
func (m *Manager) dropFilter(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
valid, fragment := m.isValidPacket(d, packetData)
if !valid {
if !m.isValidPacket(d, packetData) {
return true
}
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
if srcIP == nil {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true
}
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
// For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false
}
if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
}
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
}
// handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
if blocked {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
_, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
RuleID: ruleId,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: srcAddr,
DestIP: dstAddr,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true
}
// if running in netstack mode we need to pass this to the forwarder
if m.netstack && m.localForwarding {
if m.netstack {
return m.handleNetstackLocalTraffic(packetData)
}
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP, ruleID, size)
m.trackInbound(d, srcIP, dstIP)
// pass to either native or virtual stack (to be picked up by listeners)
return false
}
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
if !m.localForwarding {
// pass to virtual tcp/ip stack to be picked up by listeners
return false
}
fwd := m.forwarder.Load()
if fwd == nil {
if m.forwarder == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true
}
if err := fwd.InjectIncomingPacket(packetData); err != nil {
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err)
}
@@ -807,56 +715,47 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
if !m.routingEnabled {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
return true
}
// Pass to native stack if native router is enabled or forced
if m.nativeRouter.Load() {
m.trackInbound(d, srcIP, dstIP, nil, size)
if m.nativeRouter {
return false
}
proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
id, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
RuleID: id,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourceIP: srcAddr,
DestIP: dstAddr,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true
}
// Let forwarder handle the packet if it passed route ACLs
fwd := m.forwarder.Load()
if fwd == nil {
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
} else {
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject routed packet: %v", err)
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
}
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err)
}
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
@@ -887,35 +786,20 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
}
}
// isValidPacket checks if the packet is valid.
// It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace("couldn't decode packet, err: %s", err)
return false, false
return false
}
l := len(d.decoded)
// L3 and L4 are mandatory
if l >= 2 {
return true, false
if len(d.decoded) < 2 {
m.logger.Trace("packet doesn't have network and transport layers")
return false
}
// Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
ip4 := d.ip4
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
return true, true
}
}
m.logger.Trace("packet doesn't have network and transport layers")
return false, false
return true
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
@@ -924,7 +808,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
size,
)
case layers.LayerTypeUDP:
@@ -933,7 +816,6 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
size,
)
case layers.LayerTypeICMPv4:
@@ -941,8 +823,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr,
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
size,
)
// TODO: ICMPv6
@@ -962,22 +844,20 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded
}
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) {
if m.isSpecialICMP(d) {
return nil, false
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
return mgmtId, filter
}
@@ -1002,10 +882,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
if rule.matchByIP && !ip.Equal(rule.ip) {
continue
}
@@ -1039,28 +919,24 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return nil, false, false
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
// routeACLsPass returns treu if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
return rule.mgmtId, rule.action == firewall.ActionAccept
}
}
return nil, false
}
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
destMatched = true
break
}
}
if !destMatched {
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) {
return false
}
@@ -1096,7 +972,9 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
func (m *Manager) AddUDPPacketHook(
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
@@ -1106,22 +984,23 @@ func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook fu
udpHook: hook,
}
if ip.Is4() {
if ip.To4() != nil {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
m.outgoingRules[r.ip.String()][r.id] = r
}
m.mutex.Unlock()
return r.id
@@ -1162,52 +1041,29 @@ func (m *Manager) EnableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.determineRouting(); err != nil {
return fmt.Errorf("determine routing: %w", err)
}
if m.forwarder.Load() == nil {
return nil
}
rule, err := m.blockInvalidRouted(m.wgIface)
if err != nil {
return fmt.Errorf("block invalid routed: %w", err)
}
m.blockRule = rule
return nil
return m.determineRouting()
}
func (m *Manager) DisableRouting() error {
m.mutex.Lock()
defer m.mutex.Unlock()
fwder := m.forwarder.Load()
if fwder == nil {
if m.forwarder == nil {
return nil
}
m.routingEnabled.Store(false)
m.nativeRouter.Store(false)
m.routingEnabled = false
m.nativeRouter = false
// don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding {
return nil
}
fwder.Stop()
m.forwarder.Store(nil)
m.forwarder.Stop()
m.forwarder = nil
log.Debug("forwarder stopped")
if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil
}

View File

@@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
// Test packet
@@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut, 0)
manager.processOutgoingHooks(testOut)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0)
manager.dropFilter(testIn)
}
})
}
@@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
// Prepare test packets simulating bidirectional traffic
@@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx])
}
})
}
@@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
// Data transfer
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
// Connection teardown
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
}
})
}
@@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
// Pre-generate test packets
@@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx])
}
})
})
@@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
}
})
})
@@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}

View File

@@ -12,10 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
)
func TestPeerACLFiltering(t *testing.T) {
@@ -27,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: localIP,
Network: wgNet,
}
@@ -189,313 +188,16 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Block TCP traffic outside port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Edge Case - Port at Range Boundary",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8100,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "UDP Port Range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 5060,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
// New drop test cases
{
name: "Drop TCP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop UDP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop ICMP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop all traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop traffic from multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: false,
},
{
name: "Drop TCP traffic with source port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 32100,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Mixed rule - drop specific port but allow other ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
}
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.DropIncoming(packet)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP(tc.ruleIP),
@@ -515,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.DropIncoming(packet)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -586,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: localIP,
Network: wgNet,
}
@@ -601,11 +303,11 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load())
require.True(tb, manager.routingEnabled)
require.False(tb, manager.nativeRouter)
tb.Cleanup(func() {
require.NoError(tb, manager.Close(nil))
@@ -619,7 +321,7 @@ func TestRouteACLFiltering(t *testing.T) {
type rule struct {
sources []netip.Prefix
dest fw.Network
dest netip.Prefix
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -645,7 +347,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept,
@@ -661,7 +363,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept,
@@ -677,7 +379,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept,
@@ -693,7 +395,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{53}},
action: fw.ActionAccept,
@@ -707,7 +409,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolICMP,
action: fw.ActionAccept,
},
@@ -722,7 +424,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -738,7 +440,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -754,7 +456,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -770,7 +472,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -786,7 +488,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345}},
action: fw.ActionAccept,
@@ -805,7 +507,7 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
@@ -819,7 +521,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL,
action: fw.ActionAccept,
},
@@ -834,13 +536,33 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Multiple source networks with mismatched protocol",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
// Should not match TCP rule
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
@@ -850,7 +572,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
action: fw.ActionAccept,
@@ -866,7 +588,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
action: fw.ActionAccept,
@@ -882,7 +604,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL,
srcPort: &fw.Port{Values: []uint16{12345}},
dstPort: &fw.Port{Values: []uint16{80}},
@@ -899,7 +621,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{
IsRange: true,
@@ -918,7 +640,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{
IsRange: true,
@@ -937,7 +659,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
srcPort: &fw.Port{
IsRange: true,
@@ -956,7 +678,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
srcPort: &fw.Port{
IsRange: true,
@@ -978,7 +700,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8100,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{
IsRange: true,
@@ -997,7 +719,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 5060,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolUDP,
dstPort: &fw.Port{
IsRange: true,
@@ -1016,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL,
dstPort: &fw.Port{
IsRange: true,
@@ -1035,7 +757,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop,
@@ -1051,7 +773,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolALL,
action: fw.ActionDrop,
},
@@ -1069,158 +791,17 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop,
},
shouldPass: false,
},
{
name: "Drop empty destination set",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Set: fw.Set{}},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
action: fw.ActionDrop,
},
shouldPass: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources,
@@ -1237,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule))
})
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
@@ -1255,7 +836,7 @@ func TestRouteACLOrder(t *testing.T) {
name string
rules []struct {
sources []netip.Prefix
dest fw.Network
dest netip.Prefix
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -1276,7 +857,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Drop rules take precedence over accept",
rules: []struct {
sources []netip.Prefix
dest fw.Network
dest netip.Prefix
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -1285,7 +866,7 @@ func TestRouteACLOrder(t *testing.T) {
{
// Accept rule added first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 443}},
action: fw.ActionAccept,
@@ -1293,7 +874,7 @@ func TestRouteACLOrder(t *testing.T) {
{
// Drop rule added second but should be evaluated first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop,
@@ -1331,7 +912,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Multiple drop rules take precedence",
rules: []struct {
sources []netip.Prefix
dest fw.Network
dest netip.Prefix
proto fw.Protocol
srcPort *fw.Port
dstPort *fw.Port
@@ -1340,14 +921,14 @@ func TestRouteACLOrder(t *testing.T) {
{
// Accept all
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolALL,
action: fw.ActionAccept,
},
{
// Drop specific port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop,
@@ -1355,7 +936,7 @@ func TestRouteACLOrder(t *testing.T) {
{
// Drop different port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop,
@@ -1425,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
})
for i, p := range tc.packets {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
srcIP := net.ParseIP(p.srcIP)
dstIP := net.ParseIP(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
@@ -1434,53 +1015,3 @@ func TestRouteACLOrder(t *testing.T) {
})
}
}
func TestRouteACLSet(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
@@ -17,18 +18,17 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/management/domain"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
@@ -54,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() wgaddr.Address {
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return wgaddr.Address{}
return iface.WGAddress{}
}
return i.AddressFunc()
}
@@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
ip := netip.MustParseAddr("192.168.1.1")
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
ip net.IP
dPort uint16
hook func([]byte) bool
expectedID string
@@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
ip: net.IPv4(10, 168, 0, 1),
dPort: 8000,
hook: func([]byte) bool { return true },
},
@@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
ip: net.IPv6loopback,
dPort: 9000,
hook: func([]byte) bool { return false },
},
@@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule
if tt.in {
if len(manager.incomingRules[tt.ip]) != 1 {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else {
@@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
}
if tt.ip.Compare(addedRule.ip) != 0 {
if !tt.ip.Equal(addedRule.ip) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
@@ -269,8 +269,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), 0) {
if m.dropFilter(buf.Bytes()) {
t.Errorf("expected packet to be accepted")
return
}
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
netip.MustParseAddr("100.10.0.100"),
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0)
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}
@@ -569,11 +569,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
@@ -636,12 +636,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
drop = manager.dropFilter(inboundBuf.Bytes())
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -707,208 +707,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0)
drop = manager.dropFilter(testBuf.Bytes())
require.True(t, drop, tc.description)
})
}
}
func TestUpdateSetMerge(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
// Update the set with initial prefixes
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Test initial prefixes work
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP1 := netip.MustParseAddr("10.0.0.100")
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
newPrefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.1.0.0/24"),
}
err = manager.UpdateSet(set, newPrefixes)
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
// Check that new prefixes are included
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
// Verify the rule has all prefixes
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
}
func TestUpdateSetDeduplication(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
}
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Check the internal state for deduplication
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
// Check the prefixes are correct
expectedPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
for i, prefix := range expectedPrefixes {
require.True(t, r.destinations[i] == prefix,
"Prefix should match expected value")
}
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
// Test with overlapping prefixes of different sizes
overlappingPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/16"), // More general
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
netip.MustParsePrefix("192.168.0.0/16"), // More general
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
}
err = manager.UpdateSet(set, overlappingPrefixes)
require.NoError(t, err)
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")
// Verify they're sorted correctly (more specific prefixes should come first)
prefixes := make([]string, 0, len(r.destinations))
for _, p := range r.destinations {
prefixes = append(prefixes, p.String())
}
// Check sorted order
require.Equal(t, []string{
"10.0.0.0/16",
"10.0.0.0/24",
"192.168.0.0/16",
"192.168.1.0/24",
}, prefixes, "Prefixes should be sorted")
}
}
manager.mutex.RUnlock()
// Test functionality with all prefixes
testCases := []struct {
dstIP netip.Addr
expected bool
desc string
}{
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
}
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -5,6 +5,7 @@ import (
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
@@ -13,8 +14,6 @@ import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type RecvMessage struct {
@@ -53,10 +52,9 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -66,7 +64,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -111,17 +108,35 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeIP)
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -146,10 +161,9 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -261,6 +275,21 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}

View File

@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil {
params.Logger = getLogger()
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
mux := &UDPMuxDefault{
@@ -455,9 +455,3 @@ func newBufferHolder(size int) *bufferHolder {
buf: make([]byte, size),
}
}
func getLogger() logging.LeveledLogger {
fac := logging.NewDefaultLoggerFactory()
//fac.Writer = log.StandardLogger().Writer()
return fac.NewLogger("ice")
}

View File

@@ -17,8 +17,6 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
@@ -43,13 +41,12 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil {
params.Logger = getLogger()
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25
@@ -67,7 +64,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
// embed UDPMux
@@ -122,7 +118,6 @@ type udpConn struct {
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -164,11 +159,6 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {

View File

@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
func getFwmark() int {
if nbnet.AdvancedRouting() {
return nbnet.ControlPlaneMark
return nbnet.NetbirdFwmark
}
return 0
}

View File

@@ -9,14 +9,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice

View File

@@ -1,29 +1,29 @@
package wgaddr
package device
import (
"fmt"
"net"
)
// Address WireGuard parsed address
type Address struct {
// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) {
func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return Address{}, err
return WGAddress{}, err
}
return Address{
return WGAddress{
IP: ip,
Network: network,
}, nil
}
func (addr Address) String() string {
func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@@ -13,12 +13,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -32,7 +31,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -94,7 +93,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
@@ -124,7 +123,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *WGTunDevice) WgAddress() wgaddr.Address {
func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -13,12 +13,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -30,7 +29,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -86,7 +85,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -107,7 +106,7 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -2,7 +2,6 @@ package device
import (
"net"
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
@@ -11,16 +10,16 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
DropOutgoing(packetData []byte) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
DropIncoming(packetData []byte) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
@@ -58,7 +57,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -82,7 +81,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) {
if !filter.DropIncoming(buf[offset:]) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -14,12 +14,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
iceBind *bind.ICEBind
@@ -31,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -121,11 +120,11 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}

View File

@@ -14,13 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address wgaddr.Address
address WGAddress
wgPort int
key string
mtu int
@@ -35,7 +34,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -100,10 +99,9 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
@@ -114,7 +112,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -147,7 +145,7 @@ func (t *TunKernelDevice) Close() error {
return closErr
}
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -13,13 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunNetstackDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -35,7 +34,7 @@ type TunNetstackDevice struct {
net *netstack.Net
}
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -98,7 +97,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
@@ -117,7 +116,7 @@ func (t *TunNetstackDevice) Close() error {
return nil
}
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -12,12 +12,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type USPDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -29,7 +28,7 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
@@ -94,7 +93,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -114,7 +113,7 @@ func (t *USPDevice) Close() error {
return nil
}
func (t *USPDevice) WgAddress() wgaddr.Address {
func (t *USPDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -13,14 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -33,7 +32,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -119,7 +118,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -140,7 +139,7 @@ func (t *TunDevice) Close() error {
}
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -57,7 +56,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address WGAddress) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)

View File

@@ -8,8 +8,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -92,7 +90,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address WGAddress) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {

View File

@@ -7,14 +7,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice

View File

@@ -19,7 +19,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -29,6 +28,8 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -71,7 +72,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
func (w *WGIface) Address() wgaddr.Address {
func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
@@ -102,7 +103,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := wgaddr.ParseWGAddress(newAddr)
addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}

View File

@@ -3,18 +3,17 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,

View File

@@ -6,18 +6,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -5,18 +5,17 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,13 +8,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
@@ -22,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -35,7 +34,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,17 +4,16 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -6,7 +6,6 @@ package mocks
import (
net "net"
"net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -36,7 +35,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
@@ -50,31 +49,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
}
// DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret := m.ctrl.Call(m, "DropIncoming", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
}
// DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
}
// RemovePacketHook mocks base method.

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"sync"
log "github.com/sirupsen/logrus"
@@ -17,13 +16,13 @@ import (
type ProxyBind struct {
Bind *bind.ICEBind
fakeNetIP *netip.AddrPort
wgBindEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
wgAddr *net.UDPAddr
wgEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
@@ -34,24 +33,20 @@ type ProxyBind struct {
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr)
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
if err != nil {
return err
}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.wgAddr = addr
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return nil
return err
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
return p.wgAddr
}
func (p *ProxyBind) Work() {
@@ -59,8 +54,6 @@ func (p *ProxyBind) Work() {
return
}
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
@@ -100,7 +93,7 @@ func (p *ProxyBind) close() error {
p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
p.Bind.RemoveEndpoint(p.wgAddr)
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
@@ -133,7 +126,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
msg := bind.RecvMessage{
Endpoint: p.wgBindEndpoint,
Endpoint: p.wgEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
@@ -141,19 +134,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}

View File

@@ -22,8 +22,6 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True
######################################################################
@@ -70,9 +68,6 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH
@@ -85,36 +80,8 @@ Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand
Exch $1
Push $2
@@ -196,16 +163,6 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR"
@@ -229,10 +186,7 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
# wait the service uninstall take unblock the executable
Sleep 3000

View File

@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination manager.Network,
destination netip.Prefix,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,

View File

@@ -18,7 +18,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -26,12 +25,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
type protoMatch struct {
ips map[string]int
policyID []byte
ApplyFiltering(networkMap *mgmProto.NetworkMap)
}
// DefaultManager uses firewall manager to handle
@@ -54,7 +48,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
//
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.mutex.Lock()
defer d.mutex.Unlock()
@@ -83,7 +77,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
@@ -177,16 +171,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule, dynamicResolver)
id, err := d.applyRouteACL(rule)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
@@ -209,7 +203,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty
}
@@ -223,9 +217,15 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
sources = append(sources, source)
}
destination, err := determineDestination(rule, dynamicResolver, sources)
if err != nil {
return "", fmt.Errorf("determine destination: %w", err)
var destination netip.Prefix
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
@@ -240,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
addedRule, err := d.firewall.AddRouteFiltering(rule.Id, sources, destination, protocol, nil, dPorts, action)
if err != nil {
return "", fmt.Errorf("add route rule: %w", err)
}
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
var rules []firewall.Rule
switch r.Direction {
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
rules, err = d.addInRules(r.Id, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
rules, err = d.addOutRules(r.Id, ip, protocol, port, action, ipsetName)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
@@ -388,8 +388,10 @@ func (d *DefaultManager) squashAcceptRules(
}
}
in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
type protoMatch map[mgmProto.RuleProtocol]map[string]int
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
@@ -402,18 +404,14 @@ func (d *DefaultManager) squashAcceptRules(
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
protocols[r.Protocol] = map[string]int{}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
protocols[r.Protocol] = map[string]int{}
}
// special case, when we receive this all network IP address
@@ -425,7 +423,7 @@ func (d *DefaultManager) squashAcceptRules(
return
}
ipset := protocols[r.Protocol].ips
ipset := protocols[r.Protocol]
if _, ok := ipset[r.PeerIP]; ok {
return
@@ -451,10 +449,9 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.RuleProtocol_UDP,
}
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
@@ -467,7 +464,6 @@ func (d *DefaultManager) squashAcceptRules(
Direction: direction,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
PolicyID: match.policyID,
})
squashedProtocols[protocol] = struct{}{}
@@ -496,9 +492,9 @@ func (d *DefaultManager) squashAcceptRules(
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
continue
}
}
@@ -575,33 +571,6 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil
}
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@@ -1,6 +1,7 @@
package acl
import (
"context"
"net"
"testing"
@@ -8,13 +9,13 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger()
func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
@@ -48,7 +49,7 @@ func TestDefaultManager(t *testing.T) {
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
@@ -66,7 +67,7 @@ func TestDefaultManager(t *testing.T) {
acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
@@ -92,7 +93,7 @@ func TestDefaultManager(t *testing.T) {
},
)
acl.ApplyFiltering(networkMap, false)
acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 {
@@ -116,13 +117,13 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false)
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
@@ -342,7 +343,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
IP: ip,
Network: network,
}).AnyTimes()
@@ -359,7 +360,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}(fw)
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false)
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
}
// Address mocks base method.
func (m *MockIFaceMapper) Address() wgaddr.Address {
func (m *MockIFaceMapper) Address() iface.WGAddress {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(wgaddr.Address)
ret0, _ := ret[0].(iface.WGAddress)
return ret0
}

View File

@@ -94,17 +94,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier)
params := []oauth2.AuthCodeOption{
authURL := p.oAuthConfig.AuthCodeURL(
state,
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
}
if !p.providerConfig.DisablePromptLogin {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
)
return AuthFlowInfo{
VerificationURIComplete: authURL,

View File

@@ -1,49 +0,0 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestPromptLogin(t *testing.T) {
tt := []struct {
name string
prompt bool
}{
{"PromptLogin", true},
{"NoPromptLogin", false},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
DisablePromptLogin: !tc.prompt,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
}
authInfo, err := pkce.RequestAuthInfo(context.Background())
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
pattern := "prompt=login"
if tc.prompt {
require.Contains(t, authInfo.VerificationURIComplete, pattern)
} else {
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More