Compare commits

..

72 Commits

Author SHA1 Message Date
Viktor Liu
19178b59ec Merge branch 'main' into userspace-router 2025-02-07 13:32:09 +01:00
Viktor Liu
48f58d776c Treat the whole localhost range as local IPs 2025-01-29 22:43:14 +01:00
Viktor Liu
4d635e3c2f Merge branch 'main' into userspace-router 2025-01-29 22:00:41 +01:00
Viktor Liu
a0ca3edb9f Merge branch 'port-range-acl' into userspace-router 2025-01-23 19:29:40 +01:00
Viktor Liu
0837864cfc Fix port 2025-01-23 19:29:25 +01:00
Viktor Liu
e3d4f9819f Add test cases for port ranges in peer ACLs 2025-01-23 19:28:42 +01:00
Viktor Liu
da43d33540 Merge branch 'port-range-acl' into userspace-router 2025-01-23 19:27:45 +01:00
Viktor Liu
b951fb4aec Use uppercase field name 2025-01-23 18:14:16 +01:00
Viktor Liu
862d548d4d Support port ranges 2025-01-23 18:04:00 +01:00
Viktor Liu
9b5c0439e9 Make debug ops a bit safer 2025-01-16 20:28:38 +01:00
Viktor Liu
21a3679590 Fix regression 2025-01-15 18:27:15 +01:00
Viktor Liu
77afcc8454 Merge branch 'main' into userspace-router 2025-01-15 17:53:22 +01:00
Viktor Liu
22991b3963 Process drop rules first (#3167) 2025-01-15 17:51:08 +01:00
Viktor Liu
ea6c947f5d Merge branch 'main' into userspace-router 2025-01-15 17:01:32 +01:00
Viktor Liu
8dce13113d Disable local forwarding in netstack mode by default for security reasons 2025-01-14 18:11:02 +01:00
Viktor Liu
a625f90ea8 Merge branch 'main' into userspace-router 2025-01-14 10:49:04 +01:00
Viktor Liu
1c00870ca6 Merge branch 'allow-ssh-server-freebsd' into userspace-router 2025-01-11 12:40:32 +01:00
Viktor Liu
1296ecf96e Tidy up go mod 2025-01-11 02:38:31 +01:00
Viktor Liu
8430c37dd6 Fix panic if login cmd fails 2025-01-11 02:36:11 +01:00
Viktor Liu
648b22aca1 Fix listening in netstack mode 2025-01-11 02:36:11 +01:00
Viktor Liu
d31543cb12 Enable ssh server on freebsd 2025-01-11 02:36:11 +01:00
Viktor Liu
af46f259ac Block wg net forwarded traffic 2025-01-11 00:38:30 +01:00
Viktor Liu
01957a305d Merge branch 'main' into userspace-router 2025-01-10 10:07:57 +01:00
Viktor Liu
706f98c1f1 Improve routing decision logic 2025-01-09 17:37:12 +01:00
Viktor Liu
6335ef8b48 Correct comment 2025-01-09 14:10:33 +01:00
Viktor Liu
daf935942c Handle disable-server-routes flag in userspace router 2025-01-09 14:08:44 +01:00
Viktor Liu
28f5cd523a Merge branch 'main' into userspace-router 2025-01-08 11:38:30 +01:00
Viktor Liu
2060242092 Merge branch 'main' into userspace-router 2025-01-07 15:40:19 +01:00
Viktor Liu
5ea39dfe8a Adjust limits for iOS 2025-01-06 23:53:12 +01:00
Viktor Liu
4a189a87ce Use MTU for udp max size 2025-01-06 23:47:03 +01:00
Viktor Liu
fe7a2aa263 Fix test 2025-01-05 22:44:04 +01:00
Viktor Liu
290e6992a8 Demote close error levels 2025-01-05 14:50:17 +01:00
Viktor Liu
474fb33305 Remove established field from udp and icmp (unused) 2025-01-04 23:46:32 +01:00
Viktor Liu
766e0cccc9 Add packet tracer 2025-01-03 21:09:40 +01:00
Viktor Liu
7dfe7e426e Always use userspace routing in netstack mode 2025-01-03 18:38:57 +01:00
Viktor Liu
eaadb75144 Add env var to force userspace routing if native routing is available 2025-01-03 18:02:35 +01:00
Viktor Liu
0b116b3941 Use native firewall for nat/firewall operations if available 2025-01-03 17:52:36 +01:00
Viktor Liu
f69dd6fb62 Make extra IPs from interfaces optional 2025-01-03 16:54:46 +01:00
Viktor Liu
62a20f5f1a Add local IPs test 2025-01-03 16:50:00 +01:00
Viktor Liu
a6ad4dcf22 Close endpoint when stopping udp forwarder 2025-01-03 16:40:51 +01:00
Viktor Liu
f26b418e83 Allow to set firewall log level 2025-01-03 16:04:00 +01:00
Viktor Liu
3ce39905c6 Merge branch 'main' into userspace-router 2025-01-03 15:51:05 +01:00
Viktor Liu
979fe6bb6a Reduce complexity and fix linter issues 2025-01-03 15:43:28 +01:00
Viktor Liu
c68be6b61b Remove fractions of seconds 2025-01-03 15:18:36 +01:00
Viktor Liu
fc799effda Set log level from logrus 2025-01-03 15:16:30 +01:00
Viktor Liu
955b2b98e1 Complete route ACLs and add tests 2025-01-03 15:16:23 +01:00
Viktor Liu
9490e9095b Reduce complexity 2025-01-03 11:50:51 +01:00
Viktor Liu
d711172f67 Fix benchmarks 2025-01-03 11:30:55 +01:00
Viktor Liu
0c2fa38e26 Exclude benchmark from CI 2025-01-03 11:27:52 +01:00
Viktor Liu
88b420da6d Remove linux restriction 2025-01-03 00:23:35 +01:00
Viktor Liu
2930288f2d Fix test expectation 2025-01-03 00:22:09 +01:00
Viktor Liu
0b9854b2b1 Fix tests 2025-01-03 00:01:40 +01:00
Viktor Liu
f772a21f37 Fix log level handling 2025-01-02 19:02:40 +01:00
Viktor Liu
e912f2d7c0 Fix double close in logger 2025-01-02 19:02:40 +01:00
Viktor Liu
568d064089 Drop certain forwarded icmp packets 2025-01-02 19:02:40 +01:00
Viktor Liu
911f86ded8 Support local IPs in netstack mode 2025-01-02 19:02:40 +01:00
Viktor Liu
2b8092dfad Close endpoints 2025-01-02 16:41:54 +01:00
Viktor Liu
c3c6afa37b Merge branch 'main' into userspace-router 2025-01-02 16:25:04 +01:00
Viktor Liu
fa27369b59 Fix linter issues 2025-01-02 16:21:03 +01:00
Viktor Liu
657413b8a6 Move icmp acceptance logic 2025-01-02 15:59:53 +01:00
Viktor Liu
d85e57e819 Handle other icmp types in forwarder 2025-01-02 15:59:53 +01:00
Viktor Liu
7667886794 Add more tcp logging 2025-01-02 15:17:53 +01:00
Viktor Liu
a12a9ac290 Handle all local IPs 2025-01-02 14:59:41 +01:00
Viktor Liu
ed22d79f04 Add more control with env vars, also allow to pass traffic to native firewall 2025-01-02 13:40:36 +01:00
Viktor Liu
509b4e2132 Lower udp timeout and add teardown messages 2024-12-31 16:06:17 +01:00
Viktor Liu
fb1a10755a Fix lint and test issues 2024-12-31 14:38:59 +01:00
Viktor Liu
9feaa8d767 Add icmp forwarder 2024-12-31 12:23:16 +01:00
Viktor Liu
6a97d44d5d Improve udp implementation 2024-12-31 00:34:05 +01:00
Viktor Liu
d2616544fe Add logger 2024-12-31 00:34:05 +01:00
Viktor Liu
fad82ee65c Add stop methods and improve udp implementation 2024-12-30 14:30:53 +01:00
Viktor Liu
b43a8c56df Update wireguard-go ref 2024-12-30 13:05:26 +01:00
Viktor Liu
4199da4a45 Add userspace routing 2024-12-30 01:38:28 +01:00
453 changed files with 9944 additions and 29005 deletions

View File

@@ -31,22 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**Is any other VPN software installed?**
**NetBird status -dA output:**
If yes, which one?
If applicable, add the `netbird status -dA' command output.
**Debug output**
**Do you face any (non-mobile) client issues?**
To help us resolve the problem, please attach the following debug output
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots**
@@ -55,10 +47,3 @@ If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,4 +1,4 @@
name: "Darwin"
name: Test Code Darwin
on:
push:
@@ -12,7 +12,9 @@ concurrency:
jobs:
test:
name: "Client / Unit"
strategy:
matrix:
store: ['sqlite']
runs-on: macos-latest
steps:
- name: Install Go

View File

@@ -1,4 +1,5 @@
name: "FreeBSD"
name: Test Code FreeBSD
on:
push:
@@ -12,7 +13,6 @@ concurrency:
jobs:
test:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4

View File

@@ -1,4 +1,4 @@
name: Linux
name: Test Code Linux
on:
push:
@@ -12,21 +12,11 @@ concurrency:
jobs:
build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
with:
@@ -48,6 +38,7 @@ jobs:
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
@@ -98,7 +89,6 @@ jobs:
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
name: "Client / Unit"
needs: [build-cache]
strategy:
fail-fast: false
@@ -144,121 +134,14 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
name: "Management / Unit"
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
@@ -311,22 +194,15 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -378,21 +254,14 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
@@ -443,25 +312,16 @@ jobs:
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
@@ -503,15 +363,9 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:

View File

@@ -1,4 +1,4 @@
name: "Windows"
name: Test Code Windows
on:
push:
@@ -14,7 +14,6 @@ concurrency:
jobs:
test:
name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code

View File

@@ -1,4 +1,4 @@
name: Lint
name: golangci-lint
on: [pull_request]
permissions:
@@ -27,14 +27,7 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
name: lint
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:

View File

@@ -1,4 +1,4 @@
name: Mobile
name: Mobile build validation
on:
push:
@@ -12,7 +12,6 @@ concurrency:
jobs:
android_build:
name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
@@ -48,7 +47,6 @@ jobs:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository

View File

@@ -71,7 +71,7 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
@@ -150,7 +150,7 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4

1
.gitignore vendored
View File

@@ -29,4 +29,3 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
vendor/

View File

@@ -103,7 +103,7 @@ linters:
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.

View File

@@ -50,12 +50,10 @@ nfpms:
- netbird-ui
formats:
- deb
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/netbird.desktop
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png
- src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -69,12 +67,10 @@ nfpms:
- netbird-ui
formats:
- rpm
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/netbird.desktop
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png
- src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird

View File

@@ -1,6 +1,4 @@
<div align="center">
<br/>
<br/>
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>
</p>
@@ -33,10 +31,6 @@
<br/>
</strong>
<br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
</a>
</p>
<br>

View File

@@ -1,4 +1,4 @@
FROM alpine:3.21.3
FROM alpine:3.21.0
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -86,7 +85,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
@@ -197,7 +196,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
@@ -207,7 +206,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
@@ -272,15 +271,13 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
)
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}

View File

@@ -1,98 +0,0 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -85,17 +85,11 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {

View File

@@ -145,7 +145,6 @@ func init() {
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
@@ -154,8 +153,6 @@ func init() {
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)

View File

@@ -2,20 +2,107 @@ package cmd
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
type peerStateDetailOutput struct {
FQDN string `json:"fqdn" yaml:"fqdn"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
}
type peersStateOutput struct {
Total int `json:"total" yaml:"total"`
Connected int `json:"connected" yaml:"connected"`
Details []peerStateDetailOutput `json:"details" yaml:"details"`
}
type signalStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type managementStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutput struct {
Total int `json:"total" yaml:"total"`
Available int `json:"available" yaml:"available"`
Details []relayStateOutputDetail `json:"details" yaml:"details"`
}
type iceCandidateType struct {
Local string `json:"local" yaml:"local"`
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
var (
detailFlag bool
ipv4Flag bool
@@ -86,17 +173,18 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
outputInformationHolder := convertToStatusOutputOverview(resp)
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = parseToJSON(outputInformationHolder)
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = parseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
}
if err != nil {
@@ -126,6 +214,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
@@ -162,6 +251,175 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState()
managementOverview := managementStateOutput{
URL: managementState.GetURL(),
Connected: managementState.GetConnected(),
Error: managementState.Error,
}
signalState := pbFullStatus.GetSignalState()
signalOverview := signalStateOutput{
URL: signalState.GetURL(),
Connected: signalState.GetConnected(),
Error: signalState.Error,
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
overview := statusOutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
func mapRelays(relays []*proto.RelayState) relayStateOutput {
var relayStateDetail []relayStateOutputDetail
var relaysAvailable int
for _, relay := range relays {
available := relay.GetAvailable()
relayStateDetail = append(relayStateDetail,
relayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
},
)
if available {
relaysAvailable++
}
}
return relayStateOutput{
Total: len(relays),
Available: relaysAvailable,
Details: relayStateDetail,
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := ""
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
}
if isPeerConnected {
peersConnected++
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
}
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
peerState := peerStateDetailOutput{
IP: pbPeerState.GetIP(),
PubKey: pbPeerState.GetPubKey(),
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
},
IceCandidateEndpoint: iceCandidateType{
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetNetworks(),
Networks: pbPeerState.GetNetworks(),
}
peersStateDetail = append(peersStateDetail, peerState)
}
sortPeersByIP(peersStateDetail)
peersOverview := peersStateOutput{
Total: len(peersStateDetail),
Connected: peersConnected,
Details: peersStateDetail,
}
return peersOverview
}
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
if len(peersStateDetail) > 0 {
sort.SliceStable(peersStateDetail, func(i, j int) bool {
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
return iAddr.Compare(jAddr) == -1
})
}
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {
@@ -169,3 +427,452 @@ func parseInterfaceIP(interfaceIP string) string {
}
return fmt.Sprintf("%s\n", ip)
}
func parseToJSON(overview statusOutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func parseToYAML(overview statusOutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
networks,
networks,
peersCountString,
)
return summary
}
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
"%s\n"+
"%s",
parsedPeersString,
summary,
)
}
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""
)
for _, peerState := range peers.Details {
localICE := "-"
if peerState.IceCandidateType.Local != "" {
localICE = peerState.IceCandidateType.Local
}
remoteICE := "-"
if peerState.IceCandidateType.Remote != "" {
remoteICE = peerState.IceCandidateType.Remote
}
localICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Local != "" {
localICEEndpoint = peerState.IceCandidateEndpoint.Local
}
remoteICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "true"
} else {
if rosenpassPermissive {
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
} else {
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
}
}
} else {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
}
}
networks := "-"
if len(peerState.Networks) > 0 {
sort.Strings(peerState.Networks)
networks = strings.Join(peerState.Networks, ", ")
}
peerString := fmt.Sprintf(
"\n %s:\n"+
" NetBird IP: %s\n"+
" Public key: %s\n"+
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
peerState.Status,
peerState.ConnType,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
networks,
networks,
peerState.Latency.String(),
)
peersString += peerString
}
return peersString
}
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := true
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
statusEval = true
}
}
if len(ipsFilter) > 0 {
_, ok := ipsFilterMap[peerState.IP]
if !ok {
ipEval = true
}
}
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval
}
func toIEC(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -1,11 +1,597 @@
package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
func init() {
loc, err := time.LoadLocation("UTC")
if err != nil {
panic(err)
}
time.Local = loc
}
var resp = &proto.StatusResponse{
Status: "Connected",
FullStatus: &proto.FullStatus{
Peers: []*proto.PeerState{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
Fqdn: "peer-1.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
RemoteIceCandidateEndpoint: "",
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Networks: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
Fqdn: "peer-2.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: &proto.SignalState{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: []*proto.RelayState{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
LocalPeerState: &proto.LocalPeerState{
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
}
var overview = statusOutputOverview{
Peers: peersStateOutput{
Total: 2,
Connected: 2,
Details: []peerStateDetailOutput{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
FQDN: "peer-1.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
},
IceCandidateEndpoint: iceCandidateType{
Local: "",
Remote: "",
},
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
FQDN: "peer-2.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
},
IceCandidateEndpoint: iceCandidateType{
Local: "10.0.0.1:10001",
Remote: "10.0.10.1:10002",
},
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
CliVersion: version.NetbirdVersion(),
DaemonVersion: "0.14.1",
ManagementState: managementStateOutput{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: signalStateOutput{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: relayStateOutput{
Total: 2,
Available: 1,
Details: []relayStateOutputDetail{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
},
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
Networks: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := convertToStatusOutputOverview(resp)
assert.Equal(t, overview, convertedResult)
}
func TestSortingOfPeers(t *testing.T) {
peers := []peerStateDetailOutput{
{
IP: "192.168.178.104",
},
{
IP: "192.168.178.102",
},
{
IP: "192.168.178.101",
},
{
IP: "192.168.178.105",
},
{
IP: "192.168.178.103",
},
}
sortPeersByIP(peers)
assert.Equal(t, peers[3].IP, "192.168.178.104")
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := parseToJSON(overview)
//@formatter:off
expectedJSONString := `
{
"peers": {
"total": 2,
"connected": 2,
"details": [
{
"fqdn": "peer-1.awesome-domain.com",
"netbirdIp": "192.168.178.101",
"publicKey": "Pubkey1",
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
"iceCandidateType": {
"local": "",
"remote": ""
},
"iceCandidateEndpoint": {
"local": "",
"remote": ""
},
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
]
},
{
"fqdn": "peer-2.awesome-domain.com",
"netbirdIp": "192.168.178.102",
"publicKey": "Pubkey2",
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
},
"iceCandidateEndpoint": {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null,
"networks": null
}
]
},
"cliVersion": "development",
"daemonVersion": "0.14.1",
"management": {
"url": "my-awesome-management.com:443",
"connected": true,
"error": ""
},
"signal": {
"url": "my-awesome-signal.com:443",
"connected": true,
"error": ""
},
"relays": {
"total": 2,
"available": 1,
"details": [
{
"uri": "stun:my-awesome-stun.com:3478",
"available": true,
"error": ""
},
{
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
"available": false,
"error": "context: deadline exceeded"
}
]
},
"netbirdIp": "192.168.178.100/16",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
}`
// @formatter:on
var expectedJSON bytes.Buffer
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
assert.Equal(t, expectedJSON.String(), jsonString)
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := parseToYAML(overview)
expectedYAML :=
`peers:
total: 2
connected: 2
details:
- fqdn: peer-1.awesome-domain.com
netbirdIp: 192.168.178.101
publicKey: Pubkey1
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
networks: []
cliVersion: development
daemonVersion: 0.14.1
management:
url: my-awesome-management.com:443
connected: true
error: ""
signal:
url: my-awesome-signal.com:443
connected: true
error: ""
relays:
total: 2
available: 1
details:
- uri: stun:my-awesome-stun.com:3478
available: true
error: ""
- uri: turns:my-awesome-turn.com:443?transport=tcp
available: false
error: 'context: deadline exceeded'
netbirdIp: 192.168.178.100/16
publicKey: Some-Pub-Key
usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
Public key: Pubkey1
Status: Connected
-- detail --
Connection type: P2P
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
Public key: Pubkey2
Status: Connected
-- detail --
Connection type: Relayed
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Networks: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`
assert.Equal(t, expectedString, shortVersion)
}
func TestParsingOfIP(t *testing.T) {
InterfaceIP := "192.168.178.123/16"
@@ -13,3 +599,31 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -10,7 +10,6 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -90,13 +89,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util"
)
@@ -30,16 +29,9 @@ const (
interfaceInputType
)
const (
dnsLabelsFlag = "extra-dns-labels"
)
var (
foregroundMode bool
dnsLabels []string
dnsLabelsValidated domain.List
upCmd = &cobra.Command{
foregroundMode bool
upCmd = &cobra.Command{
Use: "up",
Short: "install, login and start Netbird client",
RunE: upFunc,
@@ -57,14 +49,6 @@ func init() {
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -83,11 +67,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
return err
}
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
if err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
if hostName != "" {
@@ -119,7 +98,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
@@ -262,8 +240,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -454,24 +430,6 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil
}
func validateDnsLabels(labels []string) (domain.List, error) {
var (
domains domain.List
err error
)
if len(labels) == 0 {
return domains, nil
}
domains, err = domain.ValidateDomains(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
return domains, nil
}
func isValidAddrPort(input string) bool {
if input == "" {
return true

View File

@@ -1,167 +0,0 @@
// Package embed provides a way to embed the NetBird client directly
// into Go programs without requiring a separate NetBird client installation.
package embed
// Basic Usage:
//
// client, err := embed.New(embed.Options{
// DeviceName: "my-service",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// })
// if err != nil {
// log.Fatal(err)
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// Complete HTTP Server Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "log"
// "net/http"
// "os"
// "os/signal"
// "syscall"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-server",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP server
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
// fmt.Fprintf(w, "Hello from netbird!")
// })
//
// // Listen on netbird network
// l, err := client.ListenTCP(":8080")
// if err != nil {
// log.Fatal(err)
// }
//
// server := &http.Server{Handler: mux}
// go func() {
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
// log.Printf("HTTP server error: %v", err)
// }
// }()
//
// log.Printf("HTTP server listening on netbird network port 8080")
//
// // Handle shutdown
// stop := make(chan os.Signal, 1)
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// <-stop
//
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := server.Shutdown(shutdownCtx); err != nil {
// log.Printf("HTTP shutdown error: %v", err)
// }
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// Complete HTTP Client Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "io"
// "log"
// "os"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-client",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP client that uses netbird network
// httpClient := client.NewHTTPClient()
// httpClient.Timeout = 10 * time.Second
//
// // Make request to server in netbird network
// target := os.Getenv("NB_TARGET")
// resp, err := httpClient.Get(target)
// if err != nil {
// log.Fatal(err)
// }
// defer resp.Body.Close()
//
// // Read and print response
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Printf("Response from server: %s\n", string(body))
//
// // Clean shutdown
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// The package provides several methods for network operations:
// - Dial: Creates outbound connections
// - ListenTCP: Creates TCP listeners
// - ListenUDP: Creates UDP listeners
//
// By default, the embed package uses userspace networking mode, which doesn't
// require root/admin privileges. For production deployments, consider setting
// appropriate config and state paths for persistence.

View File

@@ -1,293 +0,0 @@
package embed
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
connect *internal.ConnectClient
}
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
// LogLevel sets the logging level (defaults to info if empty)
LogLevel string
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
NoUserspace bool
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
ConfigPath string
// StatePath is the path to the netbird state file
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
}
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
if opts.LogLevel != "" {
level, err := logrus.ParseLevel(opts.LogLevel)
if err != nil {
return nil, fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
}
if !opts.NoUserspace {
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
if opts.StatePath != "" {
// TODO: Disable state if path not provided
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
t := true
var config *internal.Config
var err error
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
config: config,
}, nil
}
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
}
}()
select {
case <-startCtx.Done():
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client
return nil
}
// Stop gracefully stops the client.
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
func (c *Client) Stop(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect == nil {
return ErrClientNotStarted
}
done := make(chan error, 1)
go func() {
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenUDP(udpAddr)
}
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) NewHTTPClient() *http.Client {
transport := &http.Transport{
DialContext: c.Dial,
}
return &http.Client{
Transport: transport,
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
}
addr, err := engine.Address()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
}
return nsnet, addr, nil
}

View File

@@ -10,18 +10,17 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, err := uspfilter.Create(iface, disableServerRoutes)
if err != nil {
return nil, err
}

View File

@@ -15,7 +15,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -34,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
@@ -48,7 +47,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
return createUserspaceFirewall(iface, fm, disableServerRoutes)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
@@ -78,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
}
if errUsp != nil {

View File

@@ -4,13 +4,12 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() device.WGAddress
IsUserspaceBind() bool
SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice

View File

@@ -30,8 +30,10 @@ type entry struct {
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
@@ -39,10 +41,12 @@ type aclManager struct {
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
iptablesClient: iptablesClient,
wgIface: wgIface,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
@@ -75,7 +79,6 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
}
func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@@ -311,12 +314,9 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
@@ -96,22 +96,21 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
_ string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -126,7 +125,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
@@ -167,7 +166,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
}
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -197,13 +196,13 @@ func (m *Manager) AllowNetbird() error {
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
"all",
nil,
nil,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
@@ -219,30 +218,6 @@ func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
defer func() {
err := manager.Close(nil)
err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -97,17 +97,17 @@ func TestIptablesManager(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
}
})
}
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
defer func() {
err := manager.Close(nil)
err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
})
t.Run("reset check", func(t *testing.T) {
err = manager.Close(nil)
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
})
}
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
defer func() {
err := manager.Close(nil)
err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -15,8 +15,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -24,36 +23,22 @@ import (
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
@@ -77,7 +62,6 @@ type router struct {
legacyManagement bool
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@@ -85,7 +69,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
@@ -121,7 +104,6 @@ func (r *router) init(stateManager *statemanager.Manager) error {
}
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -129,7 +111,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
@@ -157,9 +139,9 @@ func (r *router) AddRouteFiltering(
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
} else {
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
}
if err != nil {
@@ -174,12 +156,12 @@ func (r *router) AddRouteFiltering(
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID()
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
@@ -230,10 +212,6 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -260,10 +238,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
@@ -290,7 +264,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
@@ -303,7 +277,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
@@ -331,7 +305,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
@@ -369,11 +343,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
chain string
table string
}{
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
@@ -393,22 +365,16 @@ func (r *router) createContainers() error {
chain string
table string
}{
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
@@ -449,6 +415,27 @@ func (r *router) addPostroutingRules() error {
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
@@ -467,43 +454,28 @@ func (r *router) addJumpRules() error {
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err)
return fmt.Errorf("add nat jump rule: %v", err)
}
r.rules[jumpNatPost] = natRule
r.rules[jumpNat] = natRule
// Jump to mangle prerouting chain
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %v", err)
}
r.rules[jumpNatPre] = rdrRule
r.rules[jumpPre] = preRule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
var table, chain string
switch ruleKey {
case jumpNatPost:
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
@@ -548,8 +520,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
r.rules[ruleKey] = rule
r.updateState()
return nil
}
@@ -565,7 +535,6 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
log.Debugf("marking rule %s not found", ruleKey)
}
r.updateState()
return nil
}
@@ -595,137 +564,6 @@ func (r *router) updateState() {
}
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[string]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string

View File

@@ -39,14 +39,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
}()
// Now 5 rules:
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
// 4. jump rule to PRE mangle chain
// 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -330,18 +328,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")

View File

@@ -12,6 +12,6 @@ type Rule struct {
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
func (r *Rule) GetRuleID() string {
return r.ruleID
}

View File

@@ -4,20 +4,21 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
@@ -61,7 +62,7 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Close(nil); err != nil {
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}

View File

@@ -26,8 +26,8 @@ const (
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// ID returns the rule id
ID() string
// GetRuleID returns the rule id
GetRuleID() string
}
// RuleDirection is the traffic direction which a rule is applied
@@ -65,13 +65,13 @@ type Manager interface {
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
id []byte,
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
ipsetName string,
comment string,
) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition
@@ -80,15 +80,7 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
@@ -102,23 +94,13 @@ type Manager interface {
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Close closes the firewall manager
Close(stateManager *statemanager.Manager) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -1,27 +0,0 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,12 +1,30 @@
package manager
import (
"fmt"
"strconv"
)
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool
@@ -15,25 +33,6 @@ type Port struct {
Values []uint16
}
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
}
// String interface implementation
func (p *Port) String() string {
var ports string

View File

@@ -1,19 +0,0 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
@@ -102,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil {
return nil, err
}
@@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
delete(m.rules, r.GetRuleID())
return m.rConn.Flush()
}
@@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
delete(m.rules, r.GetRuleID())
return m.rConn.Flush()
}
@@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return err
}
delete(m.rules, r.ID())
delete(m.rules, r.GetRuleID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
@@ -256,6 +256,7 @@ func (m *AclManager) addIOFiltering(
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
@@ -337,7 +338,7 @@ func (m *AclManager) addIOFiltering(
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(ruleId)
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -129,11 +129,10 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
}
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -148,7 +147,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
@@ -243,7 +242,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
}
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -324,14 +323,6 @@ func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
@@ -343,22 +334,6 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second * 3)
defer func() {
err = manager.Close(nil)
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Close(nil)
err = manager.Reset(nil)
require.NoError(t, err, "failed to reset")
}
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second * 3)
defer func() {
if err := manager.Close(nil); err != nil {
if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
@@ -283,11 +283,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,

View File

@@ -14,31 +14,23 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
tableNat = "nat"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
)
const refreshRulesMapError = "refresh rules map: %w"
@@ -57,18 +49,16 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
@@ -108,52 +98,7 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
return r.removeAcceptForwardRules()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -188,22 +133,14 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
}
// Add the single NAT rule that matches on mark
@@ -228,7 +165,6 @@ func (r *router) createContainers() error {
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
@@ -237,7 +173,7 @@ func (r *router) AddRouteFiltering(
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
@@ -345,7 +281,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.ID()
ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
@@ -474,10 +410,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -904,10 +836,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -968,269 +896,6 @@ func (r *router) refreshRulesMap() error {
return nil
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
@@ -1294,11 +959,15 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Range{
Op: expr.CmpOpEq,
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {

View File

@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.ID() {
if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule
break
}
@@ -595,20 +595,16 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Range:
if port.IsRange && len(port.Values) == 2 {
fromPort := binary.BigEndian.Uint16(ex.FromData)
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
portMatchFound = true
}
}
case *expr.Cmp:
if !port.IsRange {
} else {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if p == portValue {
if uint16(p) == portValue {
portMatchFound = true
break
}

View File

@@ -16,6 +16,6 @@ type Rule struct {
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
func (r *Rule) GetRuleID() string {
return r.ruleID
}

View File

@@ -3,20 +3,21 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
@@ -38,7 +39,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Close(nil); err != nil {
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}

View File

@@ -4,36 +4,39 @@ package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {
@@ -45,7 +48,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
return m.nativeFirewall.Reset(stateManager)
}
return nil
}

View File

@@ -3,7 +3,6 @@ package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
@@ -23,30 +22,30 @@ const (
)
// Reset firewall to the default state
func (m *Manager) Close(*statemanager.Manager) error {
func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
if m.forwarder != nil {
m.forwarder.Stop()
}
if m.logger != nil {

View File

@@ -3,14 +3,14 @@ package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
Address() iface.WGAddress
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -1,27 +1,20 @@
// common.go
package conntrack
import (
"fmt"
"net/netip"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
}
// these small methods will be inlined by the compiler
@@ -31,17 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
@@ -53,14 +35,92 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout
}
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}

View File

@@ -1,67 +1,94 @@
package conntrack
import (
"context"
"net/netip"
"net"
"testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})

View File

@@ -1,17 +1,13 @@
package conntrack
import (
"context"
"fmt"
"net/netip"
"net"
"sync"
"time"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -23,20 +19,18 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
ID uint16
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
ICMPType uint8
ICMPCode uint8
Sequence uint16
ID uint16
}
// ICMPTracker manages ICMP connection states
@@ -45,201 +39,131 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
flowLogger nftypes.FlowLogger
done chan struct{}
ipPool *PreallocatedIPs
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
flowLogger: flowLogger,
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq)
t.mutex.Lock()
t.connections[key] = conn
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
}
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
if !exists {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
if conn.timeoutExceeded(t.timeout) {
return false
}
return true
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
return ICMPConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
ID: id,
Sequence: seq,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
}

View File

@@ -1,39 +1,39 @@
package conntrack
import (
"net/netip"
"net"
"testing"
)
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
}
})
}

View File

@@ -3,16 +3,12 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"context"
"net/netip"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -43,35 +39,6 @@ const (
// TCPState represents the state of a TCP connection
type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const (
TCPStateNew TCPState = iota
TCPStateSynSent
@@ -86,14 +53,19 @@ const (
TCPStateClosed
)
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
}
@@ -107,126 +79,78 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
done chan struct{}
timeout time.Duration
flowLogger nftypes.FlowLogger
ipPool *PreallocatedIPs
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
done: make(chan struct{}),
timeout: timeout,
flowLogger: flowLogger,
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.Lock()
t.connections[key] = conn
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.UpdateLastSeen()
conn.established.Store(false)
t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
}
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
// Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
@@ -235,26 +159,22 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return false
}
// Handle RST flag specially - it always causes transition to closed
// Handle RST packets
if flags&TCPRst != 0 {
if conn.IsTombstone() {
conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true
}
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
return false
}
conn.Lock()
t.updateState(key, conn, flags, false)
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
@@ -263,17 +183,18 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
state := conn.State
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
return
}
switch state {
switch conn.State {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
@@ -282,11 +203,11 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateSynReceived
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
}
@@ -304,32 +225,22 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateClosing:
@@ -337,8 +248,8 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
}
case TCPStateCloseWait:
@@ -349,12 +260,17 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
}
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
}
}
@@ -399,14 +315,12 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false
}
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -417,12 +331,6 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
@@ -433,26 +341,29 @@ func (t *TCPTracker) cleanup() {
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
@@ -470,21 +381,3 @@ func isValidFlagCombination(flags uint8) bool {
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,7 @@
package conntrack
import (
"net/netip"
"net"
"testing"
"time"
@@ -9,11 +9,11 @@ import (
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake")
},
},
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
},
{
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed")
},
},
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
tt.test(t)
})
}
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: true,
desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: false,
desc: "Should reject RST without connection",
@@ -208,12 +208,7 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST()
// Verify connection state is as expected
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
@@ -225,63 +220,63 @@ func TestRSTHandling(t *testing.T) {
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
@@ -292,14 +287,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire

View File

@@ -1,15 +1,11 @@
package conntrack
import (
"context"
"net/netip"
"net"
"sync"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -22,8 +18,6 @@ const (
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
}
// UDPTracker manages UDP connection states
@@ -32,125 +26,89 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
flowLogger nftypes.FlowLogger
done chan struct{}
ipPool *PreallocatedIPs
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
flowLogger: flowLogger,
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
t.mutex.Lock()
t.connections[key] = conn
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
}
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
if !exists {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
if conn.timeoutExceeded(t.timeout) {
return false
}
return true
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
case <-t.done:
return
}
}
@@ -162,58 +120,44 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.tickerCancel()
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
return conn, exists
if !exists {
return nil, false
}
return conn, true
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,8 +1,7 @@
package conntrack
import (
"context"
"net/netip"
"net"
"testing"
"time"
@@ -30,59 +29,54 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
tracker := NewUDPTracker(tt.timeout, logger)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.tickerCancel)
assert.NotNil(t, tracker.done)
})
}
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
tracker := NewUDPTracker(1*time.Second, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
sleep time.Duration
@@ -99,7 +93,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
},
{
name: "invalid source IP",
srcIP: netip.MustParseAddr("192.168.1.4"),
srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
@@ -109,7 +103,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
@@ -149,7 +143,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got)
})
}
@@ -160,45 +154,42 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
tickerCancel: tickerCancel,
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
logger: logger,
flowLogger: flowLogger,
}
// Start cleanup routine
go tracker.cleanupRoutine(ctx)
go tracker.cleanupRoutine()
// Add some connections
connections := []struct {
srcIP netip.Addr
dstIP netip.Addr
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
}{
{
srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"),
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"),
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
}
// Verify initial connections
@@ -220,33 +211,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
}
})
}

View File

@@ -1,8 +1,6 @@
package forwarder
import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -81,10 +79,3 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -18,7 +18,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -30,7 +29,6 @@ const (
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
@@ -40,7 +38,7 @@ type Forwarder struct {
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@@ -104,10 +102,9 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
udpForwarder: newUDPForwarder(mtu, logger),
ctx: ctx,
cancel: cancel,
netstack: netstack,

View File

@@ -3,30 +3,14 @@ package forwarder
import (
"context"
"net"
"net/netip"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
@@ -34,7 +18,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
@@ -48,31 +32,47 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return
return true
}
response := make([]byte, f.endpoint.mtu)
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return
return true
}
ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -101,27 +101,9 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return
return true
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
return true
}

View File

@@ -5,38 +5,24 @@ import (
"fmt"
"io"
"net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
return
}
@@ -58,13 +44,12 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
f.logger.Trace("forwarder: established TCP connection %v", id)
go f.proxyTCP(id, inConn, outConn, ep, flowID)
go f.proxyTCP(id, inConn, outConn, ep)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
@@ -73,8 +58,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
@@ -95,38 +78,13 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
return
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -5,12 +5,10 @@ import (
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -18,7 +16,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
@@ -31,17 +28,15 @@ type udpPacketConn struct {
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
flowID uuid.UUID
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
type idleConn struct {
@@ -49,14 +44,13 @@ type idleConn struct {
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, mtu)
@@ -78,10 +72,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
conn.ep.Close()
@@ -112,10 +106,10 @@ func (f *udpForwarder) cleanup() {
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
}
idle.conn.ep.Close()
@@ -124,7 +118,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
}
}
}
@@ -143,24 +137,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
f.logger.Trace("forwarder: existing UDP connection for %v", id)
return
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
// TODO: Send ICMP error message
return
}
@@ -171,7 +155,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
@@ -184,7 +168,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn: outConn,
cancel: connCancel,
ep: ep,
flowID: flowID,
}
pConn.updateLastSeen()
@@ -194,20 +177,17 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
f.logger.Trace("forwarder: established UDP connection to %v", id)
go f.proxyUDP(connCtx, pConn, id, ep)
}
@@ -215,10 +195,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
}
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
}
ep.Close()
@@ -226,8 +206,6 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}()
errChan := make(chan error, 2)
@@ -242,43 +220,17 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
return
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
@@ -293,29 +245,33 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
defer bufPool.Put(bufp)
buffer := *bufp
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
for {
if ctx.Err() != nil {
select {
case <-ctx.Done():
return ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
default:
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
c.updateLastSeen()
c.updateLastSeen()
}
}
}

View File

@@ -3,7 +3,6 @@ package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
@@ -32,9 +31,13 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
@@ -119,12 +122,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
if ipv4 := ip.To4(); ipv4 != nil {
return m.checkBitmapBit(ipv4)
}
return false

View File

@@ -2,91 +2,90 @@ package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr wgaddr.Address
testIP netip.Addr
setupAddr iface.WGAddress
testIP net.IP
expected bool
}{
{
name: "Localhost range",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
testIP: net.ParseIP("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
testIP: net.ParseIP("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
testIP: net.ParseIP("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
testIP: net.ParseIP("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
testIP: net.ParseIP("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
setupAddr: iface.WGAddress{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
testIP: net.ParseIP("fe80::1"),
expected: false,
},
}
@@ -96,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
AddressFunc: func() iface.WGAddress {
return tt.setupAddr
},
}
@@ -175,7 +174,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
result := manager.IsLocalIP(net.ParseIP(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}

View File

@@ -1,4 +1,4 @@
// Package log provides a high-performance, non-blocking logger for userspace networking
// Package logger provides a high-performance, non-blocking logger for userspace networking
package log
import (
@@ -13,12 +13,13 @@ import (
)
const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
maxBatchSize = 1024 * 16 // 16KB max batch size
maxMessageSize = 1024 * 2 // 2KB per message
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
)
// Level represents log severity
type Level uint32
const (
@@ -41,37 +42,32 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC",
}
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
bufPool sync.Pool
output io.Writer
level atomic.Uint32
buffer *ringBuffer
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {
New: func() interface{} {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
@@ -83,149 +79,97 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
return l
}
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
*buf = (*buf)[:0]
// Timestamp
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
}
*buf = append(*buf, '\n')
}
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
func (l *Logger) Warn(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
func (l *Logger) Info(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
func (l *Logger) Debug(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
func (l *Logger) Trace(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
// worker periodically flushes the buffer
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buffer := make([]byte, 0, maxBatchSize)
buf := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
l.handleShutdown(&buffer)
return
case <-ticker.C:
l.flushBuffer(&buffer)
case msg := <-l.msgChannel:
l.processMessage(msg, &buffer)
l.processBatch(&buffer)
// Read accumulated messages
n, _ := l.buffer.Read(buf[:cap(buf)])
if n == 0 {
continue
}
// Write batch
_, _ = l.output.Write(buf[:n])
}
}
}

View File

@@ -1,121 +0,0 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -0,0 +1,85 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"net"
"net/netip"
"github.com/google/gopacket"
@@ -11,26 +12,25 @@ import (
// PeerRule to handle management of rules
type PeerRule struct {
id string
mgmtId []byte
ip netip.Addr
ip net.IP
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
comment string
udpHook func([]byte) bool
}
// ID returns the rule id
func (r *PeerRule) ID() string {
// GetRuleID returns the rule id
func (r *PeerRule) GetRuleID() string {
return r.id
}
type RouteRule struct {
id string
mgmtId []byte
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
@@ -39,7 +39,7 @@ type RouteRule struct {
action firewall.Action
}
// ID returns the rule id
func (r *RouteRule) ID() string {
// GetRuleID returns the rule id
func (r *RouteRule) GetRuleID() string {
return r.id
}

View File

@@ -2,7 +2,7 @@ package uspfilter
import (
"fmt"
"net/netip"
"net"
"time"
"github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
}
type PacketTrace struct {
SourceIP netip.Addr
DestinationIP netip.Addr
SourceIP net.IP
DestinationIP net.IP
Protocol string
SourcePort uint16
DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
}
type PacketBuilder struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcIP net.IP
DstIP net.IP
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
SrcIP: p.SrcIP,
DstIP: p.DstIP,
}
}
@@ -260,30 +260,28 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.localipmanager.IsLocalIP(dstIP) {
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter.Load() {
if m.nativeRouter {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
@@ -311,46 +309,32 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
msg := "Allowed by peer ACL rules"
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true
msg = "Blocked by peer ACL rules"
}
trace.AddResult(StagePeerACL, msg, !blocked)
trace.AddResult(StagePeerACL, msg, true)
// Handle netstack mode
if m.netstack {
switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
}
// In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled.Load() {
if !m.routingEnabled {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
@@ -366,23 +350,18 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
proto := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {
strId = "<no id>"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
msg := "Allowed by route ACLs"
if !allowed {
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
msg = "Blocked by route ACLs"
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
if allowed && m.forwarder != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
@@ -401,7 +380,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
dropped := m.processOutgoingHooks(packetData)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {

View File

@@ -1,440 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -93,7 +93,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.ActionAccept, "", "allow all")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -113,15 +114,10 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(
nil,
ip,
fw.ProtocolTCP,
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
"",
)
fw.ActionAccept, "", "explicit return")
require.NoError(b, err)
}
},
@@ -132,15 +128,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.ActionDrop, "", "default drop")
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -169,9 +158,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
@@ -193,13 +182,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -214,9 +203,9 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
@@ -230,7 +219,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
// Test packet
@@ -238,11 +227,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut, 0)
manager.processOutgoingHooks(testOut)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0)
manager.dropFilter(testIn)
}
})
}
@@ -262,9 +251,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
@@ -278,12 +267,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -461,9 +450,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
// Setup scenario
@@ -477,25 +466,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0)
manager.processOutgoingHooks(outbound)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0)
manager.dropFilter(inbound)
}
})
}
@@ -588,9 +577,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -601,7 +590,10 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -624,17 +616,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
// Prepare test packets simulating bidirectional traffic
@@ -655,9 +647,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx])
}
})
}
@@ -676,9 +668,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -689,7 +681,10 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -761,19 +756,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
// Data transfer
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
// Connection teardown
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
}
})
}
@@ -792,9 +787,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -804,7 +799,10 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -826,15 +824,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0)
manager.processOutgoingHooks(syn)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0)
manager.dropFilter(synack)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0)
manager.processOutgoingHooks(ack)
}
// Pre-generate test packets
@@ -856,8 +854,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0)
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx])
}
})
})
@@ -877,9 +875,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
@@ -888,7 +886,10 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
})
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -950,17 +951,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response)
manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer)
manager.dropFilter(p.finServer)
manager.processOutgoingHooks(p.ackClient)
}
})
})
@@ -1032,7 +1033,14 @@ func BenchmarkRouteACLs(b *testing.B) {
}
for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddRouteFiltering(
r.sources,
r.dest,
r.proto,
nil,
r.port,
fw.ActionAccept,
)
if err != nil {
b.Fatal(err)
}
@@ -1054,8 +1062,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestPeerACLFiltering(t *testing.T) {
@@ -26,20 +26,20 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: localIP,
Network: wgNet,
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false)
require.NoError(t, err)
require.NotNil(t, manager)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
require.NoError(t, manager.Reset(nil))
})
manager.wgNetwork = wgNet
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.DropIncoming(packet)
require.True(t, isDropped, "Packet should be dropped when no rules exist")
})
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP(tc.ruleIP),
tc.ruleProto,
tc.ruleSrcPort,
tc.ruleDstPort,
tc.ruleAction,
"",
tc.name,
)
require.NoError(t, err)
require.NotEmpty(t, rules)
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
})
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet, 0)
isDropped := manager.DropIncoming(packet)
require.Equal(t, tc.shouldBeBlocked, isDropped)
})
}
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: localIP,
Network: wgNet,
}
@@ -302,15 +302,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
manager, err := Create(ifaceMock, false)
require.NoError(tb, err)
require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load())
require.True(tb, manager.routingEnabled)
require.False(tb, manager.nativeRouter)
tb.Cleanup(func() {
require.NoError(tb, manager.Close(nil))
require.NoError(tb, manager.Reset(nil))
})
return manager
@@ -803,7 +802,6 @@ func TestRouteACLFiltering(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources,
tc.rule.dest,
tc.rule.proto,
@@ -818,12 +816,12 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule))
})
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
srcIP := net.ParseIP(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed)
})
}
@@ -986,7 +984,6 @@ func TestRouteACLOrder(t *testing.T) {
var rules []fw.Rule
for _, r := range tc.rules {
rule, err := manager.AddRouteFiltering(
nil,
r.sources,
r.dest,
r.proto,
@@ -1006,10 +1003,10 @@ func TestRouteACLOrder(t *testing.T) {
})
for i, p := range tc.packets {
srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP)
srcIP := net.ParseIP(p.srcIP)
dstIP := net.ParseIP(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
}
})

View File

@@ -1,10 +1,8 @@
package uspfilter
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"testing"
"time"
@@ -18,17 +16,15 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
@@ -54,9 +50,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() wgaddr.Address {
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return wgaddr.Address{}
return iface.WGAddress{}
}
return i.AddressFunc()
}
@@ -66,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -86,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -96,8 +92,9 @@ func TestManagerAddPeerFiltering(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -119,25 +116,26 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
}
ip := netip.MustParseAddr("192.168.1.1")
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -151,7 +149,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -162,7 +160,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
ip net.IP
dPort uint16
hook func([]byte) bool
expectedID string
@@ -171,7 +169,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
ip: net.IPv4(10, 168, 0, 1),
dPort: 8000,
hook: func([]byte) bool { return true },
},
@@ -179,7 +177,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
ip: net.IPv6loopback,
dPort: 9000,
hook: func([]byte) bool { return false },
},
@@ -189,18 +187,18 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule PeerRule
if tt.in {
if len(manager.incomingRules[tt.ip]) != 1 {
if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule
}
} else {
@@ -208,12 +206,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule
}
}
if tt.ip.Compare(addedRule.ip) != 0 {
if !tt.ip.Equal(addedRule.ip) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
@@ -238,7 +236,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -248,14 +246,15 @@ func TestManagerReset(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
err = m.Close(nil)
err = m.Reset(nil)
if err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
@@ -269,8 +268,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
@@ -280,7 +279,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -293,8 +292,9 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -328,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), 0) {
if m.dropFilter(buf.Bytes()) {
t.Errorf("expected packet to be accepted")
return
}
if err = m.Close(nil); err != nil {
if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
}
@@ -347,17 +347,17 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false, flowLogger)
manager, err := Create(iface, false)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
defer func() {
require.NoError(t, manager.Close(nil))
require.NoError(t, manager.Reset(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@@ -401,9 +401,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
defer func() {
require.NoError(t, manager.Close(nil))
require.NoError(t, manager.Reset(nil))
}()
manager.decoders = sync.Pool{
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
netip.MustParseAddr("100.10.0.100"),
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0)
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0)
result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}
@@ -479,12 +479,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
if err := manager.Close(nil); err != nil {
if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false)
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
@@ -530,12 +530,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
},
}
defer func() {
require.NoError(t, manager.Close(nil))
require.NoError(t, manager.Reset(nil))
}()
// Set up packet parameters
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("100.10.0.100")
srcIP := net.ParseIP("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolUDP,
}
outboundUDP := &layers.UDP{
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP.AsSlice(), // Original source is now destination
SrcIP: dstIP, // Original destination is now source
DstIP: srcIP, // Original source is now destination
Protocol: layers.IPProtocolUDP,
}
inboundUDP := &layers.UDP{
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
drop = manager.dropFilter(inboundBuf.Bytes())
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0)
drop = manager.dropFilter(testBuf.Bytes())
require.True(t, drop, tc.description)
})
}

View File

@@ -5,6 +5,7 @@ import (
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
@@ -13,8 +14,6 @@ import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type RecvMessage struct {
@@ -53,10 +52,9 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -66,7 +64,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -111,17 +108,35 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeIP)
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -146,10 +161,9 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -261,6 +275,21 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"io"
"net"
"slices"
"strings"
"sync"
@@ -153,7 +152,46 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}
mux := &UDPMuxDefault{
var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
@@ -165,55 +203,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return newBufferHolder(receiveMTU + maxAddrSize)
},
},
localAddrsForUnspecified: localAddrsForUnspecified,
}
mux.updateLocalAddresses()
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
m.mu.Lock()
m.localAddrsForUnspecified = localAddrsForUnspecified
m.mu.Unlock()
}
// LocalAddr returns the listening address of this UDPMuxDefault
@@ -223,12 +214,8 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
defer m.mu.Unlock()
if len(m.localAddrsForUnspecified) > 0 {
return slices.Clone(m.localAddrsForUnspecified)
return m.localAddrsForUnspecified
}
return []net.Addr{m.LocalAddr()}
@@ -238,10 +225,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
m.mu.Unlock()
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String())
}

View File

@@ -17,8 +17,6 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
@@ -43,7 +41,6 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -67,7 +64,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
// embed UDPMux
@@ -122,7 +118,6 @@ type udpConn struct {
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -164,11 +159,6 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {

View File

@@ -43,7 +43,13 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil
}
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -52,7 +58,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
PresharedKey: preSharedKey,

View File

@@ -52,7 +52,13 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -61,7 +67,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
@@ -356,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark
}
return 0

View File

@@ -3,23 +3,19 @@
package iface
import (
"golang.zx2c4.com/wireguard/tun/netstack"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
}

View File

@@ -1,29 +1,29 @@
package wgaddr
package device
import (
"fmt"
"net"
)
// Address WireGuard parsed address
type Address struct {
// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) {
func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return Address{}, err
return WGAddress{}, err
}
return Address{
return WGAddress{
IP: ip,
Network: network,
}, nil
}
func (addr Address) String() string {
func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@@ -9,16 +9,14 @@ import (
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -32,7 +30,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -94,7 +92,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
@@ -124,7 +122,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *WGTunDevice) WgAddress() wgaddr.Address {
func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}
@@ -132,10 +130,6 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -9,16 +9,14 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -30,7 +28,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -86,7 +84,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -107,7 +105,7 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
@@ -145,7 +143,3 @@ func (t *TunDevice) assignAddr() error {
}
return nil
}
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -2,7 +2,6 @@ package device
import (
"net"
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
@@ -11,16 +10,16 @@ import (
// PacketFilter interface for firewall abilities
type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool
DropOutgoing(packetData []byte) bool
// DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool
DropIncoming(packetData []byte) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
@@ -58,7 +57,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
@@ -82,7 +81,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) {
if !filter.DropIncoming(buf[offset:]) {
filteredBufs = append(filteredBufs, buf)
dropped++
}

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun)
wrapped.filter = filter

View File

@@ -10,16 +10,14 @@ import (
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
iceBind *bind.ICEBind
@@ -31,7 +29,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -121,11 +119,11 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
@@ -133,7 +131,3 @@ func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -10,17 +10,15 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address wgaddr.Address
address WGAddress
wgPort int
key string
mtu int
@@ -35,7 +33,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -100,10 +98,9 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
@@ -114,7 +111,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -147,7 +144,7 @@ func (t *TunKernelDevice) Close() error {
return closErr
}
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}
@@ -168,7 +165,3 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -8,18 +8,15 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/iface/netstack"
)
type TunNetstackDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -28,14 +25,12 @@ type TunNetstackDevice struct {
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
}
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -48,19 +43,13 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
}
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
t.net = net
t.device = device.NewDevice(
t.filteredDevice,
@@ -98,7 +87,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
@@ -117,7 +106,7 @@ func (t *TunNetstackDevice) Close() error {
return nil
}
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
@@ -133,7 +122,3 @@ func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
func (t *TunNetstackDevice) Device() *device.Device {
return t.device
}
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}

View File

@@ -8,16 +8,14 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type USPDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -29,7 +27,7 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
@@ -94,7 +92,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -114,7 +112,7 @@ func (t *USPDevice) Close() error {
return nil
}
func (t *USPDevice) WgAddress() wgaddr.Address {
func (t *USPDevice) WgAddress() WGAddress {
return t.address
}
@@ -137,7 +135,3 @@ func (t *USPDevice) assignAddr() error {
return link.assignAddr(t.address)
}
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -8,19 +8,17 @@ import (
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address wgaddr.Address
address WGAddress
port int
key string
mtu int
@@ -33,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -119,7 +117,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
@@ -140,7 +138,7 @@ func (t *TunDevice) Close() error {
}
return nil
}
func (t *TunDevice) WgAddress() wgaddr.Address {
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
@@ -176,7 +174,3 @@ func (t *TunDevice) assignAddr() error {
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
}
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -11,7 +11,7 @@ import (
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error

View File

@@ -6,7 +6,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -57,7 +56,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address WGAddress) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)

View File

@@ -8,8 +8,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -92,7 +90,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address wgaddr.Address) error {
func (l *wgLink) assignAddr(address WGAddress) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {

View File

@@ -3,21 +3,17 @@ package iface
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
}

View File

@@ -3,14 +3,12 @@ package iface
import (
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -19,7 +17,6 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -29,6 +26,8 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -71,7 +70,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
func (w *WGIface) Address() wgaddr.Address {
func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
@@ -102,7 +101,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := wgaddr.ParseWGAddress(newAddr)
addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}
@@ -112,13 +111,12 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock()
defer w.mu.Unlock()
netIPNets := prefixesToIPNets(allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@@ -243,22 +241,3 @@ func (w *WGIface) waitUntilRemoved() error {
}
}
}
// GetNet returns the netstack.Net for the netstack device
func (w *WGIface) GetNet() *netstack.Net {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.GetNet()
}
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

117
client/iface/iface_moc.go Normal file
View File

@@ -0,0 +1,117 @@
package iface
import (
"net"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetWGDeviceFunc func() *wgdevice.Device
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetWGDevice() *wgdevice.Device {
return m.GetWGDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
return m.GetProxyFunc()
}

View File

@@ -3,18 +3,17 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,

View File

@@ -6,18 +6,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -5,18 +5,17 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,13 +8,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
@@ -22,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -35,7 +34,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,17 +4,16 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err)
}
keepAlive := 15 * time.Second
allowedIP := netip.MustParsePrefix("10.99.99.10/32")
allowedIP := "10.99.99.10/32"
endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900")
if err != nil {
t.Fatal(err)
}
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil)
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil)
if err != nil {
t.Fatal(err)
}
@@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) {
var foundAllowedIP bool
for _, aip := range peer.AllowedIPs {
if aip.String() == allowedIP.String() {
if aip.String() == allowedIP {
foundAllowedIP = true
break
}
@@ -443,8 +443,9 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err)
}
keepAlive := 15 * time.Second
allowedIP := netip.MustParsePrefix("10.99.99.14/32")
err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil)
allowedIP := "10.99.99.14/32"
err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -461,12 +462,12 @@ func Test_RemovePeer(t *testing.T) {
func Test_ConnectPeers(t *testing.T) {
peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400)
peer1wgIP := netip.MustParsePrefix("10.99.99.17/30")
peer1wgIP := "10.99.99.17/30"
peer1Key, _ := wgtypes.GeneratePrivateKey()
peer1wgPort := 33100
peer2ifaceName := "utun500"
peer2wgIP := netip.MustParsePrefix("10.99.99.18/30")
peer2wgIP := "10.99.99.18/30"
peer2Key, _ := wgtypes.GeneratePrivateKey()
peer2wgPort := 33200
@@ -481,7 +482,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
Address: peer1wgIP.String(),
Address: peer1wgIP,
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
@@ -521,7 +522,7 @@ func Test_ConnectPeers(t *testing.T) {
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
Address: peer2wgIP.String(),
Address: peer2wgIP,
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,
@@ -557,11 +558,11 @@ func Test_ConnectPeers(t *testing.T) {
}
}()
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil)
err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil)
if err != nil {
t.Fatal(err)
}
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil)
err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil)
if err != nil {
t.Fatal(err)
}

Some files were not shown because too many files have changed in this diff Show More