mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-01 15:14:03 -04:00
Compare commits
72 Commits
set-comman
...
userspace-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
19178b59ec | ||
|
|
48f58d776c | ||
|
|
4d635e3c2f | ||
|
|
a0ca3edb9f | ||
|
|
0837864cfc | ||
|
|
e3d4f9819f | ||
|
|
da43d33540 | ||
|
|
b951fb4aec | ||
|
|
862d548d4d | ||
|
|
9b5c0439e9 | ||
|
|
21a3679590 | ||
|
|
77afcc8454 | ||
|
|
22991b3963 | ||
|
|
ea6c947f5d | ||
|
|
8dce13113d | ||
|
|
a625f90ea8 | ||
|
|
1c00870ca6 | ||
|
|
1296ecf96e | ||
|
|
8430c37dd6 | ||
|
|
648b22aca1 | ||
|
|
d31543cb12 | ||
|
|
af46f259ac | ||
|
|
01957a305d | ||
|
|
706f98c1f1 | ||
|
|
6335ef8b48 | ||
|
|
daf935942c | ||
|
|
28f5cd523a | ||
|
|
2060242092 | ||
|
|
5ea39dfe8a | ||
|
|
4a189a87ce | ||
|
|
fe7a2aa263 | ||
|
|
290e6992a8 | ||
|
|
474fb33305 | ||
|
|
766e0cccc9 | ||
|
|
7dfe7e426e | ||
|
|
eaadb75144 | ||
|
|
0b116b3941 | ||
|
|
f69dd6fb62 | ||
|
|
62a20f5f1a | ||
|
|
a6ad4dcf22 | ||
|
|
f26b418e83 | ||
|
|
3ce39905c6 | ||
|
|
979fe6bb6a | ||
|
|
c68be6b61b | ||
|
|
fc799effda | ||
|
|
955b2b98e1 | ||
|
|
9490e9095b | ||
|
|
d711172f67 | ||
|
|
0c2fa38e26 | ||
|
|
88b420da6d | ||
|
|
2930288f2d | ||
|
|
0b9854b2b1 | ||
|
|
f772a21f37 | ||
|
|
e912f2d7c0 | ||
|
|
568d064089 | ||
|
|
911f86ded8 | ||
|
|
2b8092dfad | ||
|
|
c3c6afa37b | ||
|
|
fa27369b59 | ||
|
|
657413b8a6 | ||
|
|
d85e57e819 | ||
|
|
7667886794 | ||
|
|
a12a9ac290 | ||
|
|
ed22d79f04 | ||
|
|
509b4e2132 | ||
|
|
fb1a10755a | ||
|
|
9feaa8d767 | ||
|
|
6a97d44d5d | ||
|
|
d2616544fe | ||
|
|
fad82ee65c | ||
|
|
b43a8c56df | ||
|
|
4199da4a45 |
@@ -1,27 +0,0 @@
|
||||
# More info around this file at https://www.git-town.com/configuration-file
|
||||
|
||||
[branches]
|
||||
main = "main"
|
||||
perennials = []
|
||||
perennial-regex = ""
|
||||
|
||||
[create]
|
||||
new-branch-type = "feature"
|
||||
push-new-branches = false
|
||||
|
||||
[hosting]
|
||||
dev-remote = "origin"
|
||||
# platform = ""
|
||||
# origin-hostname = ""
|
||||
|
||||
[ship]
|
||||
delete-tracking-branch = false
|
||||
strategy = "squash-merge"
|
||||
|
||||
[sync]
|
||||
feature-strategy = "merge"
|
||||
perennial-strategy = "rebase"
|
||||
prototype-strategy = "merge"
|
||||
push-hook = true
|
||||
tags = true
|
||||
upstream = false
|
||||
32
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
32
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,27 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
||||
|
||||
`netbird version`
|
||||
|
||||
**Is any other VPN software installed?**
|
||||
**NetBird status -dA output:**
|
||||
|
||||
If yes, which one?
|
||||
If applicable, add the `netbird status -dA' command output.
|
||||
|
||||
**Debug output**
|
||||
|
||||
To help us resolve the problem, please attach the following anonymized status output
|
||||
|
||||
netbird status -dA
|
||||
|
||||
Create and upload a debug bundle, and share the returned file key:
|
||||
|
||||
netbird debug for 1m -AS -U
|
||||
|
||||
*Uploaded files are automatically deleted after 30 days.*
|
||||
|
||||
|
||||
Alternatively, create the file only and attach it here manually:
|
||||
|
||||
netbird debug for 1m -AS
|
||||
**Do you face any (non-mobile) client issues?**
|
||||
|
||||
Please provide the file created by `netbird debug for 1m -AS`.
|
||||
We advise reviewing the anonymized files for any remaining PII.
|
||||
|
||||
**Screenshots**
|
||||
|
||||
@@ -60,12 +47,3 @@ If applicable, add screenshots to help explain your problem.
|
||||
**Additional context**
|
||||
|
||||
Add any other context about the problem here.
|
||||
|
||||
**Have you tried these troubleshooting steps?**
|
||||
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||
- [ ] Checked for newer NetBird versions
|
||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||
- [ ] Restarted the NetBird client
|
||||
- [ ] Disabled other VPN software
|
||||
- [ ] Checked firewall settings
|
||||
|
||||
|
||||
6
.github/pull_request_template.md
vendored
6
.github/pull_request_template.md
vendored
@@ -2,10 +2,6 @@
|
||||
|
||||
## Issue ticket number and link
|
||||
|
||||
## Stack
|
||||
|
||||
<!-- branch-stack -->
|
||||
|
||||
### Checklist
|
||||
- [ ] Is it a bug fix
|
||||
- [ ] Is a typo/documentation fix
|
||||
@@ -13,5 +9,3 @@
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] Extended the README / documentation, if necessary
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
21
.github/workflows/git-town.yml
vendored
21
.github/workflows/git-town.yml
vendored
@@ -1,21 +0,0 @@
|
||||
name: Git Town
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
name: Display the branch stack
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
6
.github/workflows/golang-test-darwin.yml
vendored
6
.github/workflows/golang-test-darwin.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: "Darwin"
|
||||
name: Test Code Darwin
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,7 +12,9 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
strategy:
|
||||
matrix:
|
||||
store: ['sqlite']
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Install Go
|
||||
|
||||
14
.github/workflows/golang-test-freebsd.yml
vendored
14
.github/workflows/golang-test-freebsd.yml
vendored
@@ -1,4 +1,5 @@
|
||||
name: "FreeBSD"
|
||||
|
||||
name: Test Code FreeBSD
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,7 +13,6 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -22,20 +22,14 @@ jobs:
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "14.1"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
pkg install -y go pkgconf xorg
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||
time go build -o netbird client/main.go
|
||||
# check all component except management, since we do not support management server on freebsd
|
||||
time go test -timeout 1m -failfast ./base62/...
|
||||
|
||||
363
.github/workflows/golang-test-linux.yml
vendored
363
.github/workflows/golang-test-linux.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Linux
|
||||
name: Test Code Linux
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,21 +12,11 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-cache:
|
||||
name: "Build Cache"
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
management:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@@ -48,6 +38,7 @@ jobs:
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
@@ -98,7 +89,6 @@ jobs:
|
||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -144,174 +134,14 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
id: go-env
|
||||
run: |
|
||||
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
${{ steps.go-env.outputs.cache_dir }}
|
||||
${{ steps.go-env.outputs.modcache_dir }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Run tests in container
|
||||
env:
|
||||
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||
run: |
|
||||
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||
|
||||
docker run --rm \
|
||||
--cap-add=NET_ADMIN \
|
||||
--privileged \
|
||||
-v $PWD:/app \
|
||||
-w /app \
|
||||
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
|
||||
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
|
||||
-e CGO_ENABLED=1 \
|
||||
-e CI=true \
|
||||
-e DOCKER_CI=true \
|
||||
-e GOARCH=${GOARCH_TARGET} \
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
golang:1.23-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
name: "Relay / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
|
||||
test_signal:
|
||||
name: "Signal / Unit"
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
go test \
|
||||
-exec 'sudo' \
|
||||
-timeout 10m ./signal/...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -339,6 +169,13 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -357,23 +194,15 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=devcert \
|
||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||
-timeout 20m ./management/...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -400,6 +229,13 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -418,52 +254,17 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags devcert -run=^$ -bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
run: docker network create promnet
|
||||
|
||||
- name: Start Prometheus Pushgateway
|
||||
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||
|
||||
- name: Start Prometheus (for Pushgateway forwarding)
|
||||
run: |
|
||||
echo '
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
scrape_configs:
|
||||
- job_name: "pushgateway"
|
||||
static_configs:
|
||||
- targets: ["pushgateway:9091"]
|
||||
remote_write:
|
||||
- url: ${{ secrets.GRAFANA_URL }}
|
||||
basic_auth:
|
||||
username: ${{ secrets.GRAFANA_USER }}
|
||||
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||
' > prometheus.yml
|
||||
|
||||
docker run -d --name prometheus --network promnet \
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
@@ -488,6 +289,13 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -504,27 +312,16 @@ jobs:
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
GIT_BRANCH=${{ github.ref_name }} \
|
||||
go test -tags=benchmark \
|
||||
-run=^$ \
|
||||
-bench=. \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||
-timeout 20m ./management/...
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
arch: [ '386','amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
@@ -552,6 +349,13 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install 32-bit libpcap
|
||||
if: matrix.arch == '386'
|
||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
@@ -559,10 +363,85 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
|
||||
|
||||
test_client_on_docker:
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
go test -tags=integration \
|
||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||
-timeout 20m ./management/...
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
${{ env.modcache }}
|
||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-gotest-cache-
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Generate Shared Sock Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||
|
||||
- name: Generate RouteManager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||
|
||||
- name: Generate SystemOps Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
||||
|
||||
- name: Generate nftables Manager Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||
|
||||
- name: Generate Engine Test bin
|
||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
||||
|
||||
- name: Generate Peer Test bin
|
||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
||||
|
||||
- run: chmod +x *testing.bin
|
||||
|
||||
- name: Run Shared Sock tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Iface tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||
|
||||
- name: Run RouteManager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run SystemOps tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run nftables Manager tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with file store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Engine tests in docker with sqlite store
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
- name: Run Peer tests in docker
|
||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
||||
|
||||
3
.github/workflows/golang-test-windows.yml
vendored
3
.github/workflows/golang-test-windows.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: "Windows"
|
||||
name: Test Code Windows
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -14,7 +14,6 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: "Client / Unit"
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Lint
|
||||
name: golangci-lint
|
||||
on: [pull_request]
|
||||
|
||||
permissions:
|
||||
@@ -19,21 +19,15 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||
include:
|
||||
- os: macos-latest
|
||||
display_name: Darwin
|
||||
- os: windows-latest
|
||||
display_name: Windows
|
||||
- os: ubuntu-latest
|
||||
display_name: Linux
|
||||
name: ${{ matrix.display_name }}
|
||||
name: lint
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Mobile
|
||||
name: Mobile build validation
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -12,7 +12,6 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
android_build:
|
||||
name: "Android / Build"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -48,7 +47,6 @@ jobs:
|
||||
CGO_ENABLED: 0
|
||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||
ios_build:
|
||||
name: "iOS / Build"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
||||
21
.github/workflows/release.yml
vendored
21
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.20"
|
||||
SIGN_PIPE_VER: "v0.0.18"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -65,20 +65,13 @@ jobs:
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
with:
|
||||
@@ -94,25 +87,25 @@ jobs:
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
retention-days: 3
|
||||
- name: upload linux packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
retention-days: 3
|
||||
- name: upload windows packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
retention-days: 3
|
||||
- name: upload macos packages
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
retention-days: 7
|
||||
retention-days: 3
|
||||
|
||||
release_ui:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -157,7 +150,7 @@ jobs:
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
|
||||
@@ -134,7 +134,6 @@ jobs:
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
|
||||
run: |
|
||||
set -x
|
||||
@@ -173,15 +172,12 @@ jobs:
|
||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||
# check relay values
|
||||
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||
grep '33445:33445' docker-compose.yml
|
||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||
grep DisablePromptLogin management.json | grep 'true'
|
||||
grep LoginFlag management.json | grep 0
|
||||
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||
|
||||
- name: Install modules
|
||||
run: go mod tidy
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,4 +29,3 @@ infrastructure_files/setup.env
|
||||
infrastructure_files/setup-*.env
|
||||
.vscode
|
||||
.DS_Store
|
||||
vendor/
|
||||
|
||||
@@ -103,7 +103,7 @@ linters:
|
||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||
- wastedassign # wastedassign finds wasted assignment statements
|
||||
issues:
|
||||
# Maximum count of issues with the same text.
|
||||
|
||||
198
.goreleaser.yaml
198
.goreleaser.yaml
@@ -96,20 +96,6 @@ builds:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird-upload
|
||||
dir: upload-server
|
||||
env: [CGO_ENABLED=0]
|
||||
binary: netbird-upload
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- arm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
universal_binaries:
|
||||
- id: netbird
|
||||
|
||||
@@ -149,7 +135,6 @@ nfpms:
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
@@ -165,7 +150,6 @@ dockers:
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
@@ -177,11 +161,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
@@ -194,12 +177,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
ids:
|
||||
- netbird
|
||||
goarch: amd64
|
||||
@@ -211,11 +193,9 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm64
|
||||
@@ -227,11 +207,9 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
ids:
|
||||
- netbird
|
||||
goarch: arm
|
||||
@@ -244,12 +222,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: amd64
|
||||
@@ -261,11 +237,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm64
|
||||
@@ -277,11 +252,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-relay
|
||||
goarch: arm
|
||||
@@ -294,11 +268,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: amd64
|
||||
@@ -310,11 +283,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm64
|
||||
@@ -326,11 +298,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-signal
|
||||
goarch: arm
|
||||
@@ -343,11 +314,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
@@ -359,11 +329,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
@@ -375,11 +344,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
@@ -392,11 +360,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: amd64
|
||||
@@ -408,11 +375,10 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm64
|
||||
@@ -424,12 +390,11 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
|
||||
- image_templates:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
ids:
|
||||
- netbird-mgmt
|
||||
goarch: arm
|
||||
@@ -442,56 +407,7 @@ dockers:
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: amd64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm64
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm64"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
- image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
ids:
|
||||
- netbird-upload
|
||||
goarch: arm
|
||||
goarm: 6
|
||||
use: buildx
|
||||
dockerfile: upload-server/Dockerfile
|
||||
build_flag_templates:
|
||||
- "--platform=linux/arm"
|
||||
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||
- "--label=maintainer=dev@netbird.io"
|
||||
docker_manifests:
|
||||
- name_template: netbirdio/netbird:{{ .Version }}
|
||||
@@ -559,95 +475,7 @@ docker_manifests:
|
||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- netbirdio/management:{{ .Version }}-debug-arm
|
||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||
- name_template: netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: netbirdio/upload:latest
|
||||
image_templates:
|
||||
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- netbirdio/upload:{{ .Version }}-arm
|
||||
- netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/relay:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/signal:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/management:debug-latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
|
||||
- name_template: ghcr.io/netbirdio/upload:latest
|
||||
image_templates:
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||
brews:
|
||||
- ids:
|
||||
- default
|
||||
|
||||
@@ -50,12 +50,10 @@ nfpms:
|
||||
- netbird-ui
|
||||
formats:
|
||||
- deb
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/netbird.desktop
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/assets/netbird.png
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
@@ -69,12 +67,10 @@ nfpms:
|
||||
- netbird-ui
|
||||
formats:
|
||||
- rpm
|
||||
scripts:
|
||||
postinstall: "release_files/ui-post-install.sh"
|
||||
contents:
|
||||
- src: client/ui/build/netbird.desktop
|
||||
- src: client/ui/netbird.desktop
|
||||
dst: /usr/share/applications/netbird.desktop
|
||||
- src: client/ui/assets/netbird.png
|
||||
- src: client/ui/netbird-systemtray-connected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
|
||||
@@ -1,64 +1,148 @@
|
||||
## Contributor License Agreement
|
||||
# Contributor License Agreement
|
||||
|
||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||
of the terms and conditions outlined below. The Contributor further represents that they are authorized to
|
||||
complete this process as described herein.
|
||||
We are incredibly thankful for the contributions we receive from the community.
|
||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||
|
||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||
our software. A CLA enables NetBird to safely commercialize our products
|
||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||
ability to use the project in their own projects or businesses, to republish modified
|
||||
source, or to completely fork the project.
|
||||
|
||||
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
|
||||
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
|
||||
|
||||
# Human-friendly summary
|
||||
|
||||
This is a human-readable summary of (and not a substitute for) the full agreement (below).
|
||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||
carefully review all the terms of the actual CLA before agreeing.
|
||||
|
||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||
in commercial products.
|
||||
</li>
|
||||
|
||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||
license to use that patent including within commercial products. You also agree that you
|
||||
have permission to grant this license.
|
||||
</li>
|
||||
|
||||
<li>No Warranty or Support Obligations.
|
||||
By making a contribution, you are not obligating yourself to provide support for the
|
||||
contribution, and you are not taking on any warranty obligations or providing any
|
||||
assurances about how it will perform.
|
||||
</li>
|
||||
|
||||
The CLA does not change the terms of the standard open source license used by our software
|
||||
such as BSD-3 or MIT.
|
||||
You are still free to use our projects within your own projects or businesses, republish
|
||||
modified source, and more.
|
||||
Please reference the appropriate license for the project you're contributing to to learn
|
||||
more.
|
||||
|
||||
# Why require a CLA?
|
||||
|
||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||
products.
|
||||
|
||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
|
||||
under the project's respective open source license, such as BSD-3.
|
||||
|
||||
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
|
||||
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
|
||||
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
|
||||
|
||||
# Signing the CLA
|
||||
|
||||
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
|
||||
to sign the CLA if you haven't already.
|
||||
|
||||
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
|
||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||
|
||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||
require you to sign again.
|
||||
|
||||
# Legal Terms and Agreement
|
||||
|
||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||
your own Contributions for any other purpose.
|
||||
|
||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||
You reserve all right, title, and interest in and to Your Contributions.
|
||||
|
||||
1. Definitions.
|
||||
|
||||
```
|
||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||
entities that control, are controlled by, or are under common control with that entity are considered
|
||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
```
|
||||
```
|
||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
```
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||
|
||||
|
||||
## 1 Preamble
|
||||
In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird
|
||||
must have a contributor license agreement on file to be signed by each Contributor, containing the license
|
||||
terms below. This license serves as protection for both the Contributor as well as NetBird and its software users;
|
||||
it does not change Contributor’s rights to use his/her own Contributions for any other purpose.
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
|
||||
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
|
||||
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
|
||||
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
|
||||
|
||||
## 2 Definitions
|
||||
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
|
||||
|
||||
2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work.
|
||||
4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to
|
||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||
with NetBird.
|
||||
|
||||
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution".
|
||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||
others). You represent that Your Contribution submissions include complete details of any third-party license or
|
||||
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
|
||||
and which are associated with any part of Your Contributions.
|
||||
|
||||
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
|
||||
|
||||
## 3 Licenses
|
||||
3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees.
|
||||
6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
|
||||
You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in
|
||||
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
||||
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
|
||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributor‘s Contribution(s) alone or by combination of Contributor’s Contribution(s) with the Work to which such Contribution(s) was Submitted.
|
||||
|
||||
3.3 NetBird hereby accepts such licenses.
|
||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
|
||||
## 4 Contributor’s Representations
|
||||
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributor’s employer has IP Rights to Contributor’s Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
|
||||
|
||||
4.2 Contributor represents that any Contribution is his/her original creation.
|
||||
|
||||
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
|
||||
|
||||
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
|
||||
|
||||
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
|
||||
|
||||
## 5 Information obligation
|
||||
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
|
||||
|
||||
## 6 Submission of Third-Party works
|
||||
Should Contributor wish to submit work that is not Contributor’s original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||
|
||||
## 7 No Consideration
|
||||
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
|
||||
|
||||
## 8 Final Provisions
|
||||
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
|
||||
|
||||
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
|
||||
|
||||
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
|
||||
|
||||
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
|
||||
|
||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||
representations inaccurate in any respect.
|
||||
|
||||
64
README.md
64
README.md
@@ -1,6 +1,4 @@
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
@@ -12,7 +10,7 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
@@ -29,14 +27,10 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||
New: NetBird Kubernetes Operator
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
@@ -57,16 +51,16 @@
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
@@ -134,37 +128,3 @@ We use open-source technologies like [WireGuard®](https://www.wireguard.com/),
|
||||
### Legal
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
## Configuration Management
|
||||
|
||||
Netbird now supports direct configuration management via CLI commands:
|
||||
|
||||
- You can use `netbird set` as a regular user if the daemon is running; it will securely update the config via the daemon.
|
||||
- If the daemon is not running, you need write access to the config file (typically requires root).
|
||||
|
||||
### Set a configuration value
|
||||
|
||||
```
|
||||
netbird set <setting> <value>
|
||||
# or using environment variables
|
||||
NB_INTERFACE_NAME=utun5 netbird set interface-name
|
||||
```
|
||||
|
||||
### Get a configuration value
|
||||
|
||||
```
|
||||
netbird get <setting>
|
||||
# or using environment variables
|
||||
NB_INTERFACE_NAME=utun5 netbird get interface-name
|
||||
```
|
||||
|
||||
### Show all configuration values
|
||||
|
||||
```
|
||||
netbird show
|
||||
```
|
||||
|
||||
- All settings support environment variable overrides: `NB_<SETTING>` or `WT_<SETTING>` (e.g. `NB_ENABLE_ROSENPASS=true`).
|
||||
- Supported settings: management-url, admin-url, interface-name, external-ip-map, extra-iface-blacklist, dns-resolver-address, extra-dns-labels, preshared-key, enable-rosenpass, rosenpass-permissive, allow-server-ssh, network-monitor, disable-auto-connect, disable-client-routes, disable-server-routes, disable-dns, disable-firewall, block-lan-access, block-inbound, enable-lazy-connection, wireguard-port, dns-router-interval.
|
||||
|
||||
See `netbird set --help`, `netbird get --help`, and `netbird show --help` for more details.
|
||||
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
FROM alpine:3.21.3
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||
|
||||
FROM alpine:3.21.0
|
||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||
ENV NB_FOREGROUND_MODE=true
|
||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
@@ -1,7 +1,6 @@
|
||||
FROM alpine:3.21.0
|
||||
|
||||
ARG NETBIRD_BINARY=netbird
|
||||
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||
COPY netbird /usr/local/bin/netbird
|
||||
|
||||
RUN apk add --no-cache ca-certificates \
|
||||
&& adduser -D -h /var/lib/netbird netbird
|
||||
|
||||
@@ -59,8 +59,6 @@ type Client struct {
|
||||
deviceName string
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
|
||||
connectClient *internal.ConnectClient
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -108,8 +106,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -134,8 +132,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -176,53 +174,6 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
func (c *Client) Networks() *NetworkArray {
|
||||
if c.connectClient == nil {
|
||||
log.Error("not connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
log.Error("could not get engine")
|
||||
return nil
|
||||
}
|
||||
|
||||
routeManager := engine.GetRouteManager()
|
||||
if routeManager == nil {
|
||||
log.Error("could not get route manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if routes[0].IsDynamic() {
|
||||
continue
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: routes[0].Network.String(),
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
return networkArray
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
}
|
||||
|
||||
type NetworkArray struct {
|
||||
items []Network
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Add(s Network) *NetworkArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Get(i int) *Network {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
func (array *NetworkArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
@@ -7,23 +7,30 @@ type PeerInfo struct {
|
||||
ConnStatus string // Todo replace to enum
|
||||
}
|
||||
|
||||
// PeerInfoArray is a wrapper of []PeerInfo
|
||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
||||
type PeerInfoCollection interface {
|
||||
Add(s string) PeerInfoCollection
|
||||
Get(i int) string
|
||||
Size() int
|
||||
}
|
||||
|
||||
// PeerInfoArray is the implementation of the PeerInfoCollection
|
||||
type PeerInfoArray struct {
|
||||
items []PeerInfo
|
||||
}
|
||||
|
||||
// Add new PeerInfo to the collection
|
||||
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
|
||||
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
||||
array.items = append(array.items, s)
|
||||
return array
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array *PeerInfoArray) Get(i int) *PeerInfo {
|
||||
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
||||
return &array.items[i]
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array *PeerInfoArray) Size() int {
|
||||
func (array PeerInfoArray) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// Preferences exports a subset of the internal config for gomobile
|
||||
// Preferences export a subset of the internal config for gomobile
|
||||
type Preferences struct {
|
||||
configInput internal.ConfigInput
|
||||
}
|
||||
|
||||
// NewPreferences creates a new Preferences instance
|
||||
// NewPreferences create new Preferences instance
|
||||
func NewPreferences(configPath string) *Preferences {
|
||||
ci := internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
||||
return &Preferences{ci}
|
||||
}
|
||||
|
||||
// GetManagementURL reads URL from config file
|
||||
// GetManagementURL read url from config file
|
||||
func (p *Preferences) GetManagementURL() (string, error) {
|
||||
if p.configInput.ManagementURL != "" {
|
||||
return p.configInput.ManagementURL, nil
|
||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
||||
return cfg.ManagementURL.String(), err
|
||||
}
|
||||
|
||||
// SetManagementURL stores the given URL and waits for commit
|
||||
// SetManagementURL store the given url and wait for commit
|
||||
func (p *Preferences) SetManagementURL(url string) {
|
||||
p.configInput.ManagementURL = url
|
||||
}
|
||||
|
||||
// GetAdminURL reads URL from config file
|
||||
// GetAdminURL read url from config file
|
||||
func (p *Preferences) GetAdminURL() (string, error) {
|
||||
if p.configInput.AdminURL != "" {
|
||||
return p.configInput.AdminURL, nil
|
||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
||||
return cfg.AdminURL.String(), err
|
||||
}
|
||||
|
||||
// SetAdminURL stores the given URL and waits for commit
|
||||
// SetAdminURL store the given url and wait for commit
|
||||
func (p *Preferences) SetAdminURL(url string) {
|
||||
p.configInput.AdminURL = url
|
||||
}
|
||||
|
||||
// GetPreSharedKey reads pre-shared key from config file
|
||||
// GetPreSharedKey read preshared key from config file
|
||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
if p.configInput.PreSharedKey != nil {
|
||||
return *p.configInput.PreSharedKey, nil
|
||||
@@ -66,160 +66,12 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
return cfg.PreSharedKey, err
|
||||
}
|
||||
|
||||
// SetPreSharedKey stores the given key and waits for commit
|
||||
// SetPreSharedKey store the given key and wait for commit
|
||||
func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||
p.configInput.RosenpassEnabled = &enabled
|
||||
}
|
||||
|
||||
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
if p.configInput.RosenpassEnabled != nil {
|
||||
return *p.configInput.RosenpassEnabled, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.RosenpassEnabled, err
|
||||
}
|
||||
|
||||
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||
p.configInput.RosenpassPermissive = &permissive
|
||||
}
|
||||
|
||||
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
if p.configInput.RosenpassPermissive != nil {
|
||||
return *p.configInput.RosenpassPermissive, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.RosenpassPermissive, err
|
||||
}
|
||||
|
||||
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||
if p.configInput.DisableClientRoutes != nil {
|
||||
return *p.configInput.DisableClientRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableClientRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableClientRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||
p.configInput.DisableClientRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||
if p.configInput.DisableServerRoutes != nil {
|
||||
return *p.configInput.DisableServerRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableServerRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableServerRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||
p.configInput.DisableServerRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableDNS reads disable DNS setting from config file
|
||||
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||
if p.configInput.DisableDNS != nil {
|
||||
return *p.configInput.DisableDNS, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableDNS, err
|
||||
}
|
||||
|
||||
// SetDisableDNS stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||
p.configInput.DisableDNS = &disable
|
||||
}
|
||||
|
||||
// GetDisableFirewall reads disable firewall setting from config file
|
||||
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||
if p.configInput.DisableFirewall != nil {
|
||||
return *p.configInput.DisableFirewall, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableFirewall, err
|
||||
}
|
||||
|
||||
// SetDisableFirewall stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||
p.configInput.DisableFirewall = &disable
|
||||
}
|
||||
|
||||
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||
if p.configInput.ServerSSHAllowed != nil {
|
||||
return *p.configInput.ServerSSHAllowed, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.ServerSSHAllowed == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.ServerSSHAllowed, err
|
||||
}
|
||||
|
||||
// SetServerSSHAllowed stores the given value and waits for commit
|
||||
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
return *p.configInput.BlockInbound, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.BlockInbound, err
|
||||
}
|
||||
|
||||
// SetBlockInbound stores the given value and waits for commit
|
||||
func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
// Commit write out the changes into config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
|
||||
@@ -26,7 +26,7 @@ type Anonymizer struct {
|
||||
}
|
||||
|
||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||
// 198.51.100.0, 100::
|
||||
// 192.51.100.0, 100::
|
||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||
}
|
||||
|
||||
@@ -69,22 +69,6 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
||||
return a.ipAnonymizer[ip]
|
||||
}
|
||||
|
||||
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||
// Convert IP to netip.Addr
|
||||
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||
if !ok {
|
||||
return addr
|
||||
}
|
||||
|
||||
anonIP := a.AnonymizeIP(ip)
|
||||
|
||||
return net.UDPAddr{
|
||||
IP: anonIP.AsSlice(),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
}
|
||||
|
||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||
|
||||
@@ -11,12 +11,8 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
const errCloseConnection = "Failed to close connection: %v"
|
||||
@@ -87,27 +83,16 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
request := &proto.DebugBundleRequest{
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||
Status: getStatusOutput(cmd),
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
}
|
||||
if debugUploadBundle {
|
||||
request.UploadURL = debugUploadBundleURL
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if debugUploadBundle {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -211,7 +196,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
||||
|
||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||
return waitErr
|
||||
@@ -221,20 +206,24 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Creating debug bundle...")
|
||||
|
||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||
request := &proto.DebugBundleRequest{
|
||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
||||
|
||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||
Anonymize: anonymizeFlag,
|
||||
Status: statusOutput,
|
||||
SystemInfo: debugSystemInfoFlag,
|
||||
}
|
||||
if debugUploadBundle {
|
||||
request.UploadURL = debugUploadBundleURL
|
||||
}
|
||||
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Disable network map persistence after creating the debug bundle
|
||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||
Enabled: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if stateWasDown {
|
||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||
@@ -249,15 +238,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||
}
|
||||
|
||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||
|
||||
if resp.GetUploadFailureReason() != "" {
|
||||
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||
}
|
||||
|
||||
if debugUploadBundle {
|
||||
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||
}
|
||||
cmd.Println(resp.GetPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -290,15 +271,13 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
func getStatusOutput(cmd *cobra.Command) string {
|
||||
var statusOutputString string
|
||||
statusResp, err := getStatus(cmd.Context())
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||
)
|
||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
||||
}
|
||||
return statusOutputString
|
||||
}
|
||||
@@ -344,34 +323,3 @@ func formatDuration(d time.Duration) string {
|
||||
s := d / time.Second
|
||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||
}
|
||||
|
||||
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||
var networkMap *mgmProto.NetworkMap
|
||||
var err error
|
||||
|
||||
if connectClient != nil {
|
||||
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get latest network map: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
debug.GeneratorDependencies{
|
||||
InternalConfig: config,
|
||||
StatusRecorder: recorder,
|
||||
NetworkMap: networkMap,
|
||||
LogFile: logFilePath,
|
||||
},
|
||||
debug.BundleConfig{
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||
}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
usr1Ch := make(chan os.Signal, 1)
|
||||
|
||||
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-usr1Ch:
|
||||
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||
|
||||
waitTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||
// Example usage with PowerShell:
|
||||
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||
// $evt.Set()
|
||||
// $evt.Close()
|
||||
func SetupDebugHandler(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
) {
|
||||
env := os.Getenv(envListenEvent)
|
||||
if env == "" {
|
||||
return
|
||||
}
|
||||
|
||||
listenEvent, err := strconv.ParseBool(env)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||
return
|
||||
}
|
||||
if !listenEvent {
|
||||
return
|
||||
}
|
||||
|
||||
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: restrict access by ACL
|
||||
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||
if err != nil {
|
||||
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||
} else {
|
||||
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if eventHandle == windows.InvalidHandle {
|
||||
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||
|
||||
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||
}
|
||||
|
||||
func waitForEvent(
|
||||
ctx context.Context,
|
||||
config *internal.Config,
|
||||
recorder *peer.Status,
|
||||
connectClient *internal.ConnectClient,
|
||||
logFilePath string,
|
||||
eventHandle windows.Handle,
|
||||
) {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||
|
||||
switch status {
|
||||
case windows.WAIT_OBJECT_0:
|
||||
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||
|
||||
// reset the event so it can be triggered again later (manual reset == 1)
|
||||
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||
}
|
||||
|
||||
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||
case uint32(windows.WAIT_TIMEOUT):
|
||||
|
||||
default:
|
||||
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var forwardingRulesCmd = &cobra.Command{
|
||||
Use: "forwarding",
|
||||
Short: "List forwarding rules",
|
||||
Long: `Commands to list forwarding rules.`,
|
||||
}
|
||||
|
||||
var forwardingRulesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List forwarding rules",
|
||||
Example: " netbird forwarding list",
|
||||
Long: "Commands to list forwarding rules.",
|
||||
RunE: listForwardingRules,
|
||||
}
|
||||
|
||||
func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
if len(resp.GetRules()) == 0 {
|
||||
cmd.Println("No forwarding rules available.")
|
||||
return nil
|
||||
}
|
||||
|
||||
printForwardingRules(cmd, resp.GetRules())
|
||||
return nil
|
||||
}
|
||||
|
||||
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||
cmd.Println("Available forwarding rules:")
|
||||
|
||||
// Sort rules by translated address
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
|
||||
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
|
||||
}
|
||||
if rules[i].GetProtocol() != rules[j].GetProtocol() {
|
||||
return rules[i].GetProtocol() < rules[j].GetProtocol()
|
||||
}
|
||||
|
||||
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
|
||||
})
|
||||
|
||||
var lastIP string
|
||||
for _, rule := range rules {
|
||||
dPort := portToString(rule.GetDestinationPort())
|
||||
tPort := portToString(rule.GetTranslatedPort())
|
||||
if lastIP != rule.GetTranslatedAddress() {
|
||||
lastIP = rule.GetTranslatedAddress()
|
||||
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
|
||||
}
|
||||
|
||||
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
|
||||
}
|
||||
}
|
||||
|
||||
func getFirstPort(portInfo *proto.PortInfo) int {
|
||||
switch v := portInfo.PortSelection.(type) {
|
||||
case *proto.PortInfo_Port:
|
||||
return int(v.Port)
|
||||
case *proto.PortInfo_Range_:
|
||||
return int(v.Range.GetStart())
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func portToString(translatedPort *proto.PortInfo) string {
|
||||
switch v := translatedPort.PortSelection.(type) {
|
||||
case *proto.PortInfo_Port:
|
||||
return fmt.Sprintf("%d", v.Port)
|
||||
case *proto.PortInfo_Range_:
|
||||
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
|
||||
default:
|
||||
return "No port specified"
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,10 +19,6 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
}
|
||||
|
||||
var loginCmd = &cobra.Command{
|
||||
Use: "login",
|
||||
Short: "login to the Netbird Management Service (first run)",
|
||||
@@ -56,9 +51,6 @@ var loginCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
// update host's static platform and system information
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
@@ -93,17 +85,11 @@ var loginCmd = &cobra.Command{
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
var dnsLabelsReq []string
|
||||
if dnsLabelsValidated != nil {
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -135,7 +121,7 @@ var loginCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
@@ -196,7 +182,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -206,7 +192,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||
@@ -220,34 +206,23 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
return &tokenInfo, nil
|
||||
}
|
||||
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
||||
var codeMsg string
|
||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||
}
|
||||
|
||||
if noBrowser {
|
||||
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||
} else {
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
}
|
||||
|
||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||
verificationURIComplete + " " + codeMsg)
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
@@ -22,27 +22,23 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
uploadBundle = "upload-bundle"
|
||||
uploadBundleURL = "upload-bundle-url"
|
||||
externalIPMapFlag = "external-ip-map"
|
||||
dnsResolverAddress = "dns-resolver-address"
|
||||
enableRosenpassFlag = "enable-rosenpass"
|
||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||
preSharedKeyFlag = "preshared-key"
|
||||
interfaceNameFlag = "interface-name"
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -78,9 +74,7 @@ var (
|
||||
anonymizeFlag bool
|
||||
debugSystemInfoFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
debugUploadBundle bool
|
||||
debugUploadBundleURL string
|
||||
lazyConnEnabled bool
|
||||
blockLANAccess bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -88,30 +82,6 @@ var (
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
}
|
||||
|
||||
getCmd = &cobra.Command{
|
||||
Use: "get <setting>",
|
||||
Short: "Get a configuration value from the config file",
|
||||
Long: `Get a configuration value from the Netbird config file. You can also use NB_<SETTING> or WT_<SETTING> environment variables to override the value (same as 'set').`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: getFunc,
|
||||
}
|
||||
|
||||
showCmd = &cobra.Command{
|
||||
Use: "show",
|
||||
Short: "Show all configuration values",
|
||||
Long: `Show all configuration values from the Netbird config file, with environment variable overrides if present.`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: showFunc,
|
||||
}
|
||||
|
||||
reloadCmd = &cobra.Command{
|
||||
Use: "reload",
|
||||
Short: "Reload the configuration in the daemon (daemon mode)",
|
||||
Long: `Reload the configuration from disk in the running daemon. Use after 'set' to apply changes without restarting the service.`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: reloadFunc,
|
||||
}
|
||||
)
|
||||
|
||||
// Execute executes the root command.
|
||||
@@ -175,11 +145,7 @@ func init() {
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(getCmd)
|
||||
rootCmd.AddCommand(showCmd)
|
||||
rootCmd.AddCommand(reloadCmd)
|
||||
|
||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||
@@ -187,8 +153,6 @@ func init() {
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
|
||||
|
||||
debugCmd.AddCommand(debugBundleCmd)
|
||||
debugCmd.AddCommand(logCmd)
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
@@ -212,11 +176,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
||||
}
|
||||
|
||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||
@@ -436,167 +397,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func getFunc(cmd *cobra.Command, args []string) error {
|
||||
setting := args[0]
|
||||
upper := strings.ToUpper(strings.ReplaceAll(setting, "-", "_"))
|
||||
if v, ok := os.LookupEnv("NB_" + upper); ok {
|
||||
cmd.Println(v)
|
||||
return nil
|
||||
} else if v, ok := os.LookupEnv("WT_" + upper); ok {
|
||||
cmd.Println(v)
|
||||
return nil
|
||||
}
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config: %v", err)
|
||||
}
|
||||
switch setting {
|
||||
case "management-url":
|
||||
cmd.Println(config.ManagementURL.String())
|
||||
case "admin-url":
|
||||
cmd.Println(config.AdminURL.String())
|
||||
case "interface-name":
|
||||
cmd.Println(config.WgIface)
|
||||
case "external-ip-map":
|
||||
cmd.Println(strings.Join(config.NATExternalIPs, ","))
|
||||
case "extra-iface-blacklist":
|
||||
cmd.Println(strings.Join(config.IFaceBlackList, ","))
|
||||
case "dns-resolver-address":
|
||||
cmd.Println(config.CustomDNSAddress)
|
||||
case "extra-dns-labels":
|
||||
cmd.Println(config.DNSLabels.SafeString())
|
||||
case "preshared-key":
|
||||
cmd.Println(config.PreSharedKey)
|
||||
case "enable-rosenpass":
|
||||
cmd.Println(config.RosenpassEnabled)
|
||||
case "rosenpass-permissive":
|
||||
cmd.Println(config.RosenpassPermissive)
|
||||
case "allow-server-ssh":
|
||||
if config.ServerSSHAllowed != nil {
|
||||
cmd.Println(*config.ServerSSHAllowed)
|
||||
} else {
|
||||
cmd.Println(false)
|
||||
}
|
||||
case "network-monitor":
|
||||
if config.NetworkMonitor != nil {
|
||||
cmd.Println(*config.NetworkMonitor)
|
||||
} else {
|
||||
cmd.Println(false)
|
||||
}
|
||||
case "disable-auto-connect":
|
||||
cmd.Println(config.DisableAutoConnect)
|
||||
case "disable-client-routes":
|
||||
cmd.Println(config.DisableClientRoutes)
|
||||
case "disable-server-routes":
|
||||
cmd.Println(config.DisableServerRoutes)
|
||||
case "disable-dns":
|
||||
cmd.Println(config.DisableDNS)
|
||||
case "disable-firewall":
|
||||
cmd.Println(config.DisableFirewall)
|
||||
case "block-lan-access":
|
||||
cmd.Println(config.BlockLANAccess)
|
||||
case "block-inbound":
|
||||
cmd.Println(config.BlockInbound)
|
||||
case "enable-lazy-connection":
|
||||
cmd.Println(config.LazyConnectionEnabled)
|
||||
case "wireguard-port":
|
||||
cmd.Println(config.WgPort)
|
||||
case "dns-router-interval":
|
||||
cmd.Println(config.DNSRouteInterval)
|
||||
default:
|
||||
return fmt.Errorf("unknown setting: %s", setting)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func showFunc(cmd *cobra.Command, args []string) error {
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config: %v", err)
|
||||
}
|
||||
settings := []string{
|
||||
"management-url", "admin-url", "interface-name", "external-ip-map", "extra-iface-blacklist", "dns-resolver-address", "extra-dns-labels", "preshared-key", "enable-rosenpass", "rosenpass-permissive", "allow-server-ssh", "network-monitor", "disable-auto-connect", "disable-client-routes", "disable-server-routes", "disable-dns", "disable-firewall", "block-lan-access", "block-inbound", "enable-lazy-connection", "wireguard-port", "dns-router-interval",
|
||||
}
|
||||
for _, setting := range settings {
|
||||
upper := strings.ToUpper(strings.ReplaceAll(setting, "-", "_"))
|
||||
var val string
|
||||
if v, ok := os.LookupEnv("NB_" + upper); ok {
|
||||
val = v + " (from NB_ env)"
|
||||
} else if v, ok := os.LookupEnv("WT_" + upper); ok {
|
||||
val = v + " (from WT_ env)"
|
||||
} else {
|
||||
switch setting {
|
||||
case "management-url":
|
||||
val = config.ManagementURL.String()
|
||||
case "admin-url":
|
||||
val = config.AdminURL.String()
|
||||
case "interface-name":
|
||||
val = config.WgIface
|
||||
case "external-ip-map":
|
||||
val = strings.Join(config.NATExternalIPs, ",")
|
||||
case "extra-iface-blacklist":
|
||||
val = strings.Join(config.IFaceBlackList, ",")
|
||||
case "dns-resolver-address":
|
||||
val = config.CustomDNSAddress
|
||||
case "extra-dns-labels":
|
||||
val = config.DNSLabels.SafeString()
|
||||
case "preshared-key":
|
||||
val = config.PreSharedKey
|
||||
case "enable-rosenpass":
|
||||
val = fmt.Sprintf("%v", config.RosenpassEnabled)
|
||||
case "rosenpass-permissive":
|
||||
val = fmt.Sprintf("%v", config.RosenpassPermissive)
|
||||
case "allow-server-ssh":
|
||||
if config.ServerSSHAllowed != nil {
|
||||
val = fmt.Sprintf("%v", *config.ServerSSHAllowed)
|
||||
} else {
|
||||
val = "false"
|
||||
}
|
||||
case "network-monitor":
|
||||
if config.NetworkMonitor != nil {
|
||||
val = fmt.Sprintf("%v", *config.NetworkMonitor)
|
||||
} else {
|
||||
val = "false"
|
||||
}
|
||||
case "disable-auto-connect":
|
||||
val = fmt.Sprintf("%v", config.DisableAutoConnect)
|
||||
case "disable-client-routes":
|
||||
val = fmt.Sprintf("%v", config.DisableClientRoutes)
|
||||
case "disable-server-routes":
|
||||
val = fmt.Sprintf("%v", config.DisableServerRoutes)
|
||||
case "disable-dns":
|
||||
val = fmt.Sprintf("%v", config.DisableDNS)
|
||||
case "disable-firewall":
|
||||
val = fmt.Sprintf("%v", config.DisableFirewall)
|
||||
case "block-lan-access":
|
||||
val = fmt.Sprintf("%v", config.BlockLANAccess)
|
||||
case "block-inbound":
|
||||
val = fmt.Sprintf("%v", config.BlockInbound)
|
||||
case "enable-lazy-connection":
|
||||
val = fmt.Sprintf("%v", config.LazyConnectionEnabled)
|
||||
case "wireguard-port":
|
||||
val = fmt.Sprintf("%d", config.WgPort)
|
||||
case "dns-router-interval":
|
||||
val = config.DNSRouteInterval.String()
|
||||
}
|
||||
}
|
||||
cmd.Printf("%-22s: %s\n", setting, val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadFunc(cmd *cobra.Command, args []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
_, err = client.ReloadConfig(cmd.Context(), &proto.ReloadConfigRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reload config in daemon: %v", err)
|
||||
}
|
||||
cmd.Println("Configuration reloaded in daemon.")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -28,19 +27,12 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
||||
}
|
||||
|
||||
func newSVCConfig() *service.Config {
|
||||
config := &service.Config{
|
||||
return &service.Config{
|
||||
Name: serviceName,
|
||||
DisplayName: "Netbird",
|
||||
Description: "Netbird mesh network client",
|
||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
||||
Option: make(service.KeyValue),
|
||||
EnvVars: make(map[string]string),
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||
|
||||
@@ -16,17 +16,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func (p *program) Start(svc service.Service) error {
|
||||
// Start should not block. Do the actual work async.
|
||||
log.Info("starting Netbird service") //nolint
|
||||
|
||||
// Collect static system and platform information
|
||||
system.UpdateStaticInfo()
|
||||
|
||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||
p.serv = grpc.NewServer()
|
||||
|
||||
@@ -120,7 +115,6 @@ var runCmd = &cobra.Command{
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||
|
||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||
if err != nil {
|
||||
|
||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||
}
|
||||
|
||||
if logFile != "" {
|
||||
if logFile != "console" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,475 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
osuser "os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var setCmd = &cobra.Command{
|
||||
Use: "set <setting> <value>",
|
||||
Short: "Set a configuration value without running up",
|
||||
Long: `Set a configuration value in the Netbird config file without running 'up'.
|
||||
|
||||
You can also set values via environment variables NB_<SETTING> or WT_<SETTING> (e.g. NB_INTERFACE_NAME=utun5 netbird set interface-name).
|
||||
|
||||
Supported settings:
|
||||
management-url (string) e.g. https://api.netbird.io:443
|
||||
admin-url (string) e.g. https://app.netbird.io:443
|
||||
interface-name (string) e.g. utun5
|
||||
external-ip-map (list) comma-separated, e.g. 12.34.56.78,12.34.56.79/eth0
|
||||
extra-iface-blacklist (list) comma-separated, e.g. eth1,eth2
|
||||
dns-resolver-address (string) e.g. 127.0.0.1:5053
|
||||
extra-dns-labels (list) comma-separated, e.g. vpc1,mgmt1
|
||||
preshared-key (string)
|
||||
enable-rosenpass (bool) true/false
|
||||
rosenpass-permissive (bool) true/false
|
||||
allow-server-ssh (bool) true/false
|
||||
network-monitor (bool) true/false
|
||||
disable-auto-connect (bool) true/false
|
||||
disable-client-routes (bool) true/false
|
||||
disable-server-routes (bool) true/false
|
||||
disable-dns (bool) true/false
|
||||
disable-firewall (bool) true/false
|
||||
block-lan-access (bool) true/false
|
||||
block-inbound (bool) true/false
|
||||
enable-lazy-connection (bool) true/false
|
||||
wireguard-port (int) e.g. 51820
|
||||
dns-router-interval (duration) e.g. 1m, 30s
|
||||
|
||||
Examples:
|
||||
NB_INTERFACE_NAME=utun5 netbird set interface-name
|
||||
netbird set wireguard-port 51820
|
||||
netbird set external-ip-map 12.34.56.78,12.34.56.79/eth0
|
||||
netbird set enable-rosenpass true
|
||||
netbird set dns-router-interval 2m
|
||||
netbird set extra-dns-labels vpc1,mgmt1
|
||||
netbird set disable-firewall true
|
||||
`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: setFunc,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(setCmd)
|
||||
}
|
||||
|
||||
func setFunc(cmd *cobra.Command, args []string) error {
|
||||
setting := args[0]
|
||||
var value string
|
||||
|
||||
// Check environment variables first
|
||||
upper := strings.ToUpper(strings.ReplaceAll(setting, "-", "_"))
|
||||
if v, ok := os.LookupEnv("NB_" + upper); ok {
|
||||
value = v
|
||||
} else if v, ok := os.LookupEnv("WT_" + upper); ok {
|
||||
value = v
|
||||
} else {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("missing value for setting %s", setting)
|
||||
}
|
||||
value = args[1]
|
||||
}
|
||||
|
||||
// If not root, try to use the daemon (only if cmd is not nil)
|
||||
if cmd != nil {
|
||||
if u, err := osuser.Current(); err == nil && u.Uid != "0" {
|
||||
conn, err := getClient(cmd)
|
||||
if err == nil {
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
_, err = client.SetConfigValue(cmd.Context(), &proto.SetConfigValueRequest{Setting: setting, Value: value})
|
||||
if err == nil {
|
||||
if cmd != nil {
|
||||
cmd.Println("Configuration updated via daemon.")
|
||||
} else {
|
||||
fmt.Println("Configuration updated via daemon.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return fmt.Errorf("daemon error: %v", s.Message())
|
||||
}
|
||||
return fmt.Errorf("failed to update config via daemon: %v", err)
|
||||
}
|
||||
// else: fall back to direct file write
|
||||
}
|
||||
}
|
||||
|
||||
switch setting {
|
||||
case "management-url":
|
||||
input := internal.ConfigInput{ConfigPath: configPath, ManagementURL: value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set management-url: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set management-url to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set management-url to: %s\n", value)
|
||||
}
|
||||
case "admin-url":
|
||||
input := internal.ConfigInput{ConfigPath: configPath, AdminURL: value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set admin-url: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set admin-url to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set admin-url to: %s\n", value)
|
||||
}
|
||||
case "interface-name":
|
||||
if err := parseInterfaceName(value); err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, InterfaceName: &value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set interface-name: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set interface-name to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set interface-name to: %s\n", value)
|
||||
}
|
||||
case "external-ip-map":
|
||||
var ips []string
|
||||
if value == "" {
|
||||
ips = []string{}
|
||||
} else {
|
||||
ips = strings.Split(value, ",")
|
||||
}
|
||||
if err := validateNATExternalIPs(ips); err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, NATExternalIPs: ips}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set external-ip-map: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set external-ip-map to: %v\n", ips)
|
||||
} else {
|
||||
fmt.Printf("Set external-ip-map to: %v\n", ips)
|
||||
}
|
||||
case "extra-iface-blacklist":
|
||||
var ifaces []string
|
||||
if value == "" {
|
||||
ifaces = []string{}
|
||||
} else {
|
||||
ifaces = strings.Split(value, ",")
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, ExtraIFaceBlackList: ifaces}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set extra-iface-blacklist: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set extra-iface-blacklist to: %v\n", ifaces)
|
||||
} else {
|
||||
fmt.Printf("Set extra-iface-blacklist to: %v\n", ifaces)
|
||||
}
|
||||
case "dns-resolver-address":
|
||||
if value != "" && !isValidAddrPort(value) {
|
||||
return fmt.Errorf("%s is invalid, it should be formatted as IP:Port string or as an empty string like \"\"", value)
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, CustomDNSAddress: []byte(value)}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set dns-resolver-address: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set dns-resolver-address to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set dns-resolver-address to: %s\n", value)
|
||||
}
|
||||
case "extra-dns-labels":
|
||||
var labels []string
|
||||
if value == "" {
|
||||
labels = []string{}
|
||||
} else {
|
||||
labels = strings.Split(value, ",")
|
||||
}
|
||||
domains, err := domain.ValidateDomains(labels)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid DNS labels: %v", err)
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DNSLabels: domains}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set extra-dns-labels: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set extra-dns-labels to: %v\n", labels)
|
||||
} else {
|
||||
fmt.Printf("Set extra-dns-labels to: %v\n", labels)
|
||||
}
|
||||
case "preshared-key":
|
||||
input := internal.ConfigInput{ConfigPath: configPath, PreSharedKey: &value}
|
||||
_, err := internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set preshared-key: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set preshared-key to: %s\n", value)
|
||||
} else {
|
||||
fmt.Printf("Set preshared-key to: %s\n", value)
|
||||
}
|
||||
case "hostname":
|
||||
// Hostname is not persisted in config, so just print a warning
|
||||
if cmd != nil {
|
||||
cmd.Printf("Warning: hostname is not persisted in config. Use --hostname with up command.\n")
|
||||
} else {
|
||||
fmt.Printf("Warning: hostname is not persisted in config. Use --hostname with up command.\n")
|
||||
}
|
||||
case "enable-rosenpass":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, RosenpassEnabled: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set enable-rosenpass: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set enable-rosenpass to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set enable-rosenpass to: %v\n", b)
|
||||
}
|
||||
case "rosenpass-permissive":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, RosenpassPermissive: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set rosenpass-permissive: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set rosenpass-permissive to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set rosenpass-permissive to: %v\n", b)
|
||||
}
|
||||
case "allow-server-ssh":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, ServerSSHAllowed: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set allow-server-ssh: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set allow-server-ssh to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set allow-server-ssh to: %v\n", b)
|
||||
}
|
||||
case "network-monitor":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, NetworkMonitor: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set network-monitor: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set network-monitor to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set network-monitor to: %v\n", b)
|
||||
}
|
||||
case "disable-auto-connect":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableAutoConnect: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-auto-connect: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-auto-connect to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-auto-connect to: %v\n", b)
|
||||
}
|
||||
case "disable-client-routes":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableClientRoutes: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-client-routes: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-client-routes to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-client-routes to: %v\n", b)
|
||||
}
|
||||
case "disable-server-routes":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableServerRoutes: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-server-routes: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-server-routes to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-server-routes to: %v\n", b)
|
||||
}
|
||||
case "disable-dns":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableDNS: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-dns: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-dns to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-dns to: %v\n", b)
|
||||
}
|
||||
case "disable-firewall":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DisableFirewall: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set disable-firewall: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set disable-firewall to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set disable-firewall to: %v\n", b)
|
||||
}
|
||||
case "block-lan-access":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, BlockLANAccess: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set block-lan-access: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set block-lan-access to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set block-lan-access to: %v\n", b)
|
||||
}
|
||||
case "block-inbound":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, BlockInbound: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set block-inbound: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set block-inbound to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set block-inbound to: %v\n", b)
|
||||
}
|
||||
case "enable-lazy-connection":
|
||||
b, err := parseBool(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, LazyConnectionEnabled: &b}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set enable-lazy-connection: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set enable-lazy-connection to: %v\n", b)
|
||||
} else {
|
||||
fmt.Printf("Set enable-lazy-connection to: %v\n", b)
|
||||
}
|
||||
case "wireguard-port":
|
||||
p, err := parseUint16(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pi := int(p)
|
||||
input := internal.ConfigInput{ConfigPath: configPath, WireguardPort: &pi}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set wireguard-port: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set wireguard-port to: %d\n", p)
|
||||
} else {
|
||||
fmt.Printf("Set wireguard-port to: %d\n", p)
|
||||
}
|
||||
case "dns-router-interval":
|
||||
d, err := time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid duration: %v", err)
|
||||
}
|
||||
input := internal.ConfigInput{ConfigPath: configPath, DNSRouteInterval: &d}
|
||||
_, err = internal.UpdateOrCreateConfig(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set dns-router-interval: %v", err)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmd.Printf("Set dns-router-interval to: %s\n", d)
|
||||
} else {
|
||||
fmt.Printf("Set dns-router-interval to: %s\n", d)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown setting: %s", setting)
|
||||
}
|
||||
|
||||
if cmd != nil {
|
||||
cmd.Println("Configuration updated successfully.")
|
||||
} else {
|
||||
fmt.Println("Configuration updated successfully.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseBool(val string) (bool, error) {
|
||||
v := strings.ToLower(val)
|
||||
if v == "true" || v == "1" {
|
||||
return true, nil
|
||||
}
|
||||
if v == "false" || v == "0" {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("invalid boolean value: %s", val)
|
||||
}
|
||||
|
||||
func parseUint16(val string) (uint16, error) {
|
||||
var p uint16
|
||||
_, err := fmt.Sscanf(val, "%d", &p)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid uint16 value: %s", val)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetCommand_AllSettings(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "config.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
// Write empty JSON object to the config file to avoid JSON parse errors
|
||||
_, err = tempFile.WriteString("{}")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
|
||||
configPath = tempFile.Name()
|
||||
|
||||
tests := []struct {
|
||||
setting string
|
||||
value string
|
||||
verify func(*testing.T, *internal.Config)
|
||||
wantErr bool
|
||||
}{
|
||||
{"management-url", "https://test.mgmt:443", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "https://test.mgmt:443", c.ManagementURL.String())
|
||||
}, false},
|
||||
{"admin-url", "https://test.admin:443", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "https://test.admin:443", c.AdminURL.String())
|
||||
}, false},
|
||||
{"interface-name", "utun99", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "utun99", c.WgIface)
|
||||
}, false},
|
||||
{"external-ip-map", "12.34.56.78,12.34.56.79", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, []string{"12.34.56.78", "12.34.56.79"}, c.NATExternalIPs)
|
||||
}, false},
|
||||
{"extra-iface-blacklist", "eth1,eth2", func(t *testing.T, c *internal.Config) {
|
||||
require.Contains(t, c.IFaceBlackList, "eth1")
|
||||
require.Contains(t, c.IFaceBlackList, "eth2")
|
||||
}, false},
|
||||
{"dns-resolver-address", "127.0.0.1:5053", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "127.0.0.1:5053", c.CustomDNSAddress)
|
||||
}, false},
|
||||
{"extra-dns-labels", "vpc1,mgmt1", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, strings.Contains(c.DNSLabels.SafeString(), "vpc1"))
|
||||
require.True(t, strings.Contains(c.DNSLabels.SafeString(), "mgmt1"))
|
||||
}, false},
|
||||
{"preshared-key", "testkey", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, "testkey", c.PreSharedKey)
|
||||
}, false},
|
||||
{"enable-rosenpass", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.RosenpassEnabled)
|
||||
}, false},
|
||||
{"rosenpass-permissive", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.RosenpassPermissive)
|
||||
}, false},
|
||||
{"allow-server-ssh", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.NotNil(t, c.ServerSSHAllowed)
|
||||
require.True(t, *c.ServerSSHAllowed)
|
||||
}, false},
|
||||
{"network-monitor", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.NotNil(t, c.NetworkMonitor)
|
||||
require.False(t, *c.NetworkMonitor)
|
||||
}, false},
|
||||
{"disable-auto-connect", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.DisableAutoConnect)
|
||||
}, false},
|
||||
{"disable-client-routes", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.DisableClientRoutes)
|
||||
}, false},
|
||||
{"disable-server-routes", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.DisableServerRoutes)
|
||||
}, false},
|
||||
{"disable-dns", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.DisableDNS)
|
||||
}, false},
|
||||
{"disable-firewall", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.DisableFirewall)
|
||||
}, false},
|
||||
{"block-lan-access", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.BlockLANAccess)
|
||||
}, false},
|
||||
{"block-inbound", "false", func(t *testing.T, c *internal.Config) {
|
||||
require.False(t, c.BlockInbound)
|
||||
}, false},
|
||||
{"enable-lazy-connection", "true", func(t *testing.T, c *internal.Config) {
|
||||
require.True(t, c.LazyConnectionEnabled)
|
||||
}, false},
|
||||
{"wireguard-port", "51820", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, 51820, c.WgPort)
|
||||
}, false},
|
||||
{"dns-router-interval", "2m", func(t *testing.T, c *internal.Config) {
|
||||
require.Equal(t, 2*time.Minute, c.DNSRouteInterval)
|
||||
}, false},
|
||||
// Invalid cases
|
||||
{"enable-rosenpass", "notabool", nil, true},
|
||||
{"wireguard-port", "notanint", nil, true},
|
||||
{"dns-router-interval", "notaduration", nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.setting+"="+tt.value, func(t *testing.T) {
|
||||
args := []string{tt.setting, tt.value}
|
||||
err := setFunc(nil, args)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCommand_EnvVars(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "config.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
// Write empty JSON object to the config file to avoid JSON parse errors
|
||||
_, err = tempFile.WriteString("{}")
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
|
||||
configPath = tempFile.Name()
|
||||
|
||||
os.Setenv("NB_INTERFACE_NAME", "utun77")
|
||||
defer os.Unsetenv("NB_INTERFACE_NAME")
|
||||
args := []string{"interface-name", "utun99"}
|
||||
err = setFunc(nil, args)
|
||||
require.NoError(t, err)
|
||||
config, err := internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "utun77", config.WgIface)
|
||||
|
||||
os.Unsetenv("NB_INTERFACE_NAME")
|
||||
os.Setenv("WT_INTERFACE_NAME", "utun88")
|
||||
defer os.Unsetenv("WT_INTERFACE_NAME")
|
||||
err = setFunc(nil, args)
|
||||
require.NoError(t, err)
|
||||
config, err = internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "utun88", config.WgIface)
|
||||
|
||||
os.Unsetenv("WT_INTERFACE_NAME")
|
||||
// No env var, should use CLI value
|
||||
err = setFunc(nil, args)
|
||||
require.NoError(t, err)
|
||||
config, err = internal.ReadConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "utun99", config.WgIface)
|
||||
}
|
||||
@@ -2,20 +2,107 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
type peerStateDetailOutput struct {
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
Status string `json:"status" yaml:"status"`
|
||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
}
|
||||
|
||||
type peersStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Connected int `json:"connected" yaml:"connected"`
|
||||
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type signalStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type managementStateOutput struct {
|
||||
URL string `json:"url" yaml:"url"`
|
||||
Connected bool `json:"connected" yaml:"connected"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutputDetail struct {
|
||||
URI string `json:"uri" yaml:"uri"`
|
||||
Available bool `json:"available" yaml:"available"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type relayStateOutput struct {
|
||||
Total int `json:"total" yaml:"total"`
|
||||
Available int `json:"available" yaml:"available"`
|
||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
||||
}
|
||||
|
||||
type iceCandidateType struct {
|
||||
Local string `json:"local" yaml:"local"`
|
||||
Remote string `json:"remote" yaml:"remote"`
|
||||
}
|
||||
|
||||
type nsServerGroupStateOutput struct {
|
||||
Servers []string `json:"servers" yaml:"servers"`
|
||||
Domains []string `json:"domains" yaml:"domains"`
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type statusOutputOverview struct {
|
||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
||||
Routes []string `json:"routes" yaml:"routes"`
|
||||
Networks []string `json:"networks" yaml:"networks"`
|
||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
||||
}
|
||||
|
||||
var (
|
||||
detailFlag bool
|
||||
ipv4Flag bool
|
||||
@@ -44,7 +131,7 @@ func init() {
|
||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -69,10 +156,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
status := resp.GetStatus()
|
||||
|
||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||
status == string(internal.StatusSessionExpired) {
|
||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
||||
cmd.Printf("Daemon status: %s\n\n"+
|
||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||
" netbird up \n\n"+
|
||||
@@ -89,17 +173,18 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
||||
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
||||
case jsonFlag:
|
||||
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||
statusOutputString, err = parseToJSON(outputInformationHolder)
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -120,7 +205,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||
}
|
||||
@@ -129,13 +214,14 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
||||
}
|
||||
|
||||
func parseFilters() error {
|
||||
|
||||
switch strings.ToLower(statusFilter) {
|
||||
case "", "idle", "connecting", "connected":
|
||||
case "", "disconnected", "connected":
|
||||
if strings.ToLower(statusFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
@@ -165,6 +251,175 @@ func enableDetailFlagWhenFilterFlag() {
|
||||
}
|
||||
}
|
||||
|
||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
managementState := pbFullStatus.GetManagementState()
|
||||
managementOverview := managementStateOutput{
|
||||
URL: managementState.GetURL(),
|
||||
Connected: managementState.GetConnected(),
|
||||
Error: managementState.Error,
|
||||
}
|
||||
|
||||
signalState := pbFullStatus.GetSignalState()
|
||||
signalOverview := signalStateOutput{
|
||||
URL: signalState.GetURL(),
|
||||
Connected: signalState.GetConnected(),
|
||||
Error: signalState.Error,
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
||||
|
||||
overview := statusOutputOverview{
|
||||
Peers: peersOverview,
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: resp.GetDaemonVersion(),
|
||||
ManagementState: managementOverview,
|
||||
SignalState: signalOverview,
|
||||
Relays: relayOverview,
|
||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
||||
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
|
||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
||||
}
|
||||
|
||||
if anonymizeFlag {
|
||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||
anonymizeOverview(anonymizer, &overview)
|
||||
}
|
||||
|
||||
return overview
|
||||
}
|
||||
|
||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
||||
var relayStateDetail []relayStateOutputDetail
|
||||
|
||||
var relaysAvailable int
|
||||
for _, relay := range relays {
|
||||
available := relay.GetAvailable()
|
||||
relayStateDetail = append(relayStateDetail,
|
||||
relayStateOutputDetail{
|
||||
URI: relay.URI,
|
||||
Available: available,
|
||||
Error: relay.GetError(),
|
||||
},
|
||||
)
|
||||
|
||||
if available {
|
||||
relaysAvailable++
|
||||
}
|
||||
}
|
||||
|
||||
return relayStateOutput{
|
||||
Total: len(relays),
|
||||
Available: relaysAvailable,
|
||||
Details: relayStateDetail,
|
||||
}
|
||||
}
|
||||
|
||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
||||
for _, pbNsGroupServer := range servers {
|
||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
||||
Servers: pbNsGroupServer.GetServers(),
|
||||
Domains: pbNsGroupServer.GetDomains(),
|
||||
Enabled: pbNsGroupServer.GetEnabled(),
|
||||
Error: pbNsGroupServer.GetError(),
|
||||
})
|
||||
}
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
||||
var peersStateDetail []peerStateDetailOutput
|
||||
peersConnected := 0
|
||||
for _, pbPeerState := range peers {
|
||||
localICE := ""
|
||||
remoteICE := ""
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := ""
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
peersConnected++
|
||||
|
||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
transferSent = pbPeerState.GetBytesTx()
|
||||
}
|
||||
|
||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
||||
peerState := peerStateDetailOutput{
|
||||
IP: pbPeerState.GetIP(),
|
||||
PubKey: pbPeerState.GetPubKey(),
|
||||
Status: pbPeerState.GetConnStatus(),
|
||||
LastStatusUpdate: timeLocal,
|
||||
ConnType: connType,
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: localICE,
|
||||
Remote: remoteICE,
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: localICEEndpoint,
|
||||
Remote: remoteICEEndpoint,
|
||||
},
|
||||
RelayAddress: relayServerAddress,
|
||||
FQDN: pbPeerState.GetFqdn(),
|
||||
LastWireguardHandshake: lastHandshake,
|
||||
TransferReceived: transferReceived,
|
||||
TransferSent: transferSent,
|
||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
||||
Routes: pbPeerState.GetNetworks(),
|
||||
Networks: pbPeerState.GetNetworks(),
|
||||
}
|
||||
|
||||
peersStateDetail = append(peersStateDetail, peerState)
|
||||
}
|
||||
|
||||
sortPeersByIP(peersStateDetail)
|
||||
|
||||
peersOverview := peersStateOutput{
|
||||
Total: len(peersStateDetail),
|
||||
Connected: peersConnected,
|
||||
Details: peersStateDetail,
|
||||
}
|
||||
return peersOverview
|
||||
}
|
||||
|
||||
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
||||
if len(peersStateDetail) > 0 {
|
||||
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
||||
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
||||
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
||||
return iAddr.Compare(jAddr) == -1
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func parseInterfaceIP(interfaceIP string) string {
|
||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||
if err != nil {
|
||||
@@ -172,3 +427,452 @@ func parseInterfaceIP(interfaceIP string) string {
|
||||
}
|
||||
return fmt.Sprintf("%s\n", ip)
|
||||
}
|
||||
|
||||
func parseToJSON(overview statusOutputOverview) (string, error) {
|
||||
jsonBytes, err := json.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("json marshal failed")
|
||||
}
|
||||
return string(jsonBytes), err
|
||||
}
|
||||
|
||||
func parseToYAML(overview statusOutputOverview) (string, error) {
|
||||
yamlBytes, err := yaml.Marshal(overview)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("yaml marshal failed")
|
||||
}
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
if showURL {
|
||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
||||
}
|
||||
} else {
|
||||
managementConnString = "Disconnected"
|
||||
if overview.ManagementState.Error != "" {
|
||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
var signalConnString string
|
||||
if overview.SignalState.Connected {
|
||||
signalConnString = "Connected"
|
||||
if showURL {
|
||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
||||
}
|
||||
} else {
|
||||
signalConnString = "Disconnected"
|
||||
if overview.SignalState.Error != "" {
|
||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
||||
}
|
||||
}
|
||||
|
||||
interfaceTypeString := "Userspace"
|
||||
interfaceIP := overview.IP
|
||||
if overview.KernelInterface {
|
||||
interfaceTypeString = "Kernel"
|
||||
} else if overview.IP == "" {
|
||||
interfaceTypeString = "N/A"
|
||||
interfaceIP = "N/A"
|
||||
}
|
||||
|
||||
var relaysString string
|
||||
if showRelays {
|
||||
for _, relay := range overview.Relays.Details {
|
||||
available := "Available"
|
||||
reason := ""
|
||||
if !relay.Available {
|
||||
available = "Unavailable"
|
||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
||||
}
|
||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
||||
}
|
||||
} else {
|
||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(overview.Networks) > 0 {
|
||||
sort.Strings(overview.Networks)
|
||||
networks = strings.Join(overview.Networks, ", ")
|
||||
}
|
||||
|
||||
var dnsServersString string
|
||||
if showNameServers {
|
||||
for _, nsServerGroup := range overview.NSServerGroups {
|
||||
enabled := "Available"
|
||||
if !nsServerGroup.Enabled {
|
||||
enabled = "Unavailable"
|
||||
}
|
||||
errorString := ""
|
||||
if nsServerGroup.Error != "" {
|
||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
||||
errorString = strings.TrimSpace(errorString)
|
||||
}
|
||||
|
||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
||||
if domainsString == "" {
|
||||
domainsString = "." // Show "." for the default zone
|
||||
}
|
||||
dnsServersString += fmt.Sprintf(
|
||||
"\n [%s] for [%s] is %s%s",
|
||||
strings.Join(nsServerGroup.Servers, ", "),
|
||||
domainsString,
|
||||
enabled,
|
||||
errorString,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if overview.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
if overview.RosenpassPermissive {
|
||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
goarch := runtime.GOARCH
|
||||
goarm := ""
|
||||
if goarch == "arm" {
|
||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf(
|
||||
"OS: %s\n"+
|
||||
"Daemon version: %s\n"+
|
||||
"CLI version: %s\n"+
|
||||
"Management: %s\n"+
|
||||
"Signal: %s\n"+
|
||||
"Relays: %s\n"+
|
||||
"Nameservers: %s\n"+
|
||||
"FQDN: %s\n"+
|
||||
"NetBird IP: %s\n"+
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Routes: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Peers count: %s\n",
|
||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
||||
overview.DaemonVersion,
|
||||
version.NetbirdVersion(),
|
||||
managementConnString,
|
||||
signalConnString,
|
||||
relaysString,
|
||||
dnsServersString,
|
||||
overview.FQDN,
|
||||
interfaceIP,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
networks,
|
||||
networks,
|
||||
peersCountString,
|
||||
)
|
||||
return summary
|
||||
}
|
||||
|
||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
summary := parseGeneralSummary(overview, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
"%s\n"+
|
||||
"%s",
|
||||
parsedPeersString,
|
||||
summary,
|
||||
)
|
||||
}
|
||||
|
||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
||||
var (
|
||||
peersString = ""
|
||||
)
|
||||
|
||||
for _, peerState := range peers.Details {
|
||||
|
||||
localICE := "-"
|
||||
if peerState.IceCandidateType.Local != "" {
|
||||
localICE = peerState.IceCandidateType.Local
|
||||
}
|
||||
|
||||
remoteICE := "-"
|
||||
if peerState.IceCandidateType.Remote != "" {
|
||||
remoteICE = peerState.IceCandidateType.Remote
|
||||
}
|
||||
|
||||
localICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Local != "" {
|
||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
||||
}
|
||||
|
||||
remoteICEEndpoint := "-"
|
||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
||||
}
|
||||
|
||||
rosenpassEnabledStatus := "false"
|
||||
if rosenpassEnabled {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "true"
|
||||
} else {
|
||||
if rosenpassPermissive {
|
||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
||||
} else {
|
||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if peerState.RosenpassEnabled {
|
||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
||||
}
|
||||
}
|
||||
|
||||
networks := "-"
|
||||
if len(peerState.Networks) > 0 {
|
||||
sort.Strings(peerState.Networks)
|
||||
networks = strings.Join(peerState.Networks, ", ")
|
||||
}
|
||||
|
||||
peerString := fmt.Sprintf(
|
||||
"\n %s:\n"+
|
||||
" NetBird IP: %s\n"+
|
||||
" Public key: %s\n"+
|
||||
" Status: %s\n"+
|
||||
" -- detail --\n"+
|
||||
" Connection type: %s\n"+
|
||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
||||
" Relay server address: %s\n"+
|
||||
" Last connection update: %s\n"+
|
||||
" Last WireGuard handshake: %s\n"+
|
||||
" Transfer status (received/sent) %s/%s\n"+
|
||||
" Quantum resistance: %s\n"+
|
||||
" Routes: %s\n"+
|
||||
" Networks: %s\n"+
|
||||
" Latency: %s\n",
|
||||
peerState.FQDN,
|
||||
peerState.IP,
|
||||
peerState.PubKey,
|
||||
peerState.Status,
|
||||
peerState.ConnType,
|
||||
localICE,
|
||||
remoteICE,
|
||||
localICEEndpoint,
|
||||
remoteICEEndpoint,
|
||||
peerState.RelayAddress,
|
||||
timeAgo(peerState.LastStatusUpdate),
|
||||
timeAgo(peerState.LastWireguardHandshake),
|
||||
toIEC(peerState.TransferReceived),
|
||||
toIEC(peerState.TransferSent),
|
||||
rosenpassEnabledStatus,
|
||||
networks,
|
||||
networks,
|
||||
peerState.Latency.String(),
|
||||
)
|
||||
|
||||
peersString += peerString
|
||||
}
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
|
||||
if statusFilter != "" {
|
||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
||||
if lowerStatusFilter == "disconnected" && isConnected {
|
||||
statusEval = true
|
||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
||||
statusEval = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipsFilter) > 0 {
|
||||
_, ok := ipsFilterMap[peerState.IP]
|
||||
if !ok {
|
||||
ipEval = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(prefixNamesFilter) > 0 {
|
||||
for prefixNameFilter := range prefixNamesFilterMap {
|
||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
||||
nameEval = false
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
func toIEC(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB",
|
||||
float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
||||
count := 0
|
||||
for _, server := range dnsServers {
|
||||
if server.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
||||
func timeAgo(t time.Time) string {
|
||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
||||
return "-"
|
||||
}
|
||||
duration := time.Since(t)
|
||||
switch {
|
||||
case duration < time.Second:
|
||||
return "Now"
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Seconds())
|
||||
if seconds == 1 {
|
||||
return "1 second ago"
|
||||
}
|
||||
return fmt.Sprintf("%d seconds ago", seconds)
|
||||
case duration < time.Hour:
|
||||
minutes := int(duration.Minutes())
|
||||
seconds := int(duration.Seconds()) % 60
|
||||
if minutes == 1 {
|
||||
if seconds == 1 {
|
||||
return "1 minute, 1 second ago"
|
||||
} else if seconds > 0 {
|
||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
||||
}
|
||||
return "1 minute ago"
|
||||
}
|
||||
if seconds > 0 {
|
||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d minutes ago", minutes)
|
||||
case duration < 24*time.Hour:
|
||||
hours := int(duration.Hours())
|
||||
minutes := int(duration.Minutes()) % 60
|
||||
if hours == 1 {
|
||||
if minutes == 1 {
|
||||
return "1 hour, 1 minute ago"
|
||||
} else if minutes > 0 {
|
||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
||||
}
|
||||
return "1 hour ago"
|
||||
}
|
||||
if minutes > 0 {
|
||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%d hours ago", hours)
|
||||
}
|
||||
|
||||
days := int(duration.Hours()) / 24
|
||||
hours := int(duration.Hours()) % 24
|
||||
if days == 1 {
|
||||
if hours == 1 {
|
||||
return "1 day, 1 hour ago"
|
||||
} else if hours > 0 {
|
||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
||||
}
|
||||
return "1 day ago"
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
||||
}
|
||||
return fmt.Sprintf("%d days ago", days)
|
||||
}
|
||||
|
||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
||||
}
|
||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
||||
}
|
||||
|
||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Networks {
|
||||
peer.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
||||
}
|
||||
|
||||
for i, route := range peer.Routes {
|
||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
anonymizePeerDetail(a, &peer)
|
||||
overview.Peers.Details[i] = peer
|
||||
}
|
||||
|
||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
||||
|
||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
||||
for i, detail := range overview.Relays.Details {
|
||||
detail.URI = a.AnonymizeURI(detail.URI)
|
||||
detail.Error = a.AnonymizeString(detail.Error)
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
}
|
||||
for j, ns := range nsGroup.Servers {
|
||||
host, port, err := net.SplitHostPort(ns)
|
||||
if err == nil {
|
||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
for i, route := range overview.Routes {
|
||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,597 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
func init() {
|
||||
loc, err := time.LoadLocation("UTC")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
time.Local = loc
|
||||
}
|
||||
|
||||
var resp = &proto.StatusResponse{
|
||||
Status: "Connected",
|
||||
FullStatus: &proto.FullStatus{
|
||||
Peers: []*proto.PeerState{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
Fqdn: "peer-1.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
||||
Relayed: false,
|
||||
LocalIceCandidateType: "",
|
||||
RemoteIceCandidateType: "",
|
||||
LocalIceCandidateEndpoint: "",
|
||||
RemoteIceCandidateEndpoint: "",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
||||
BytesRx: 200,
|
||||
BytesTx: 100,
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
Fqdn: "peer-2.awesome-domain.com",
|
||||
ConnStatus: "Connected",
|
||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
||||
Relayed: true,
|
||||
LocalIceCandidateType: "relay",
|
||||
RemoteIceCandidateType: "prflx",
|
||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
||||
BytesRx: 2000,
|
||||
BytesTx: 1000,
|
||||
Latency: durationpb.New(time.Duration(10000000)),
|
||||
},
|
||||
},
|
||||
ManagementState: &proto.ManagementState{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: &proto.SignalState{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: []*proto.RelayState{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
LocalPeerState: &proto.LocalPeerState{
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
Fqdn: "some-localhost.awesome-domain.com",
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
},
|
||||
DnsServers: []*proto.NSGroupState{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
},
|
||||
DaemonVersion: "0.14.1",
|
||||
}
|
||||
|
||||
var overview = statusOutputOverview{
|
||||
Peers: peersStateOutput{
|
||||
Total: 2,
|
||||
Connected: 2,
|
||||
Details: []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
PubKey: "Pubkey1",
|
||||
FQDN: "peer-1.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
||||
ConnType: "P2P",
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "",
|
||||
Remote: "",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
||||
TransferReceived: 200,
|
||||
TransferSent: 100,
|
||||
Routes: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.1.0.0/24",
|
||||
},
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
PubKey: "Pubkey2",
|
||||
FQDN: "peer-2.awesome-domain.com",
|
||||
Status: "Connected",
|
||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
||||
ConnType: "Relayed",
|
||||
IceCandidateType: iceCandidateType{
|
||||
Local: "relay",
|
||||
Remote: "prflx",
|
||||
},
|
||||
IceCandidateEndpoint: iceCandidateType{
|
||||
Local: "10.0.0.1:10001",
|
||||
Remote: "10.0.10.1:10002",
|
||||
},
|
||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
||||
TransferReceived: 2000,
|
||||
TransferSent: 1000,
|
||||
Latency: time.Duration(10000000),
|
||||
},
|
||||
},
|
||||
},
|
||||
CliVersion: version.NetbirdVersion(),
|
||||
DaemonVersion: "0.14.1",
|
||||
ManagementState: managementStateOutput{
|
||||
URL: "my-awesome-management.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
SignalState: signalStateOutput{
|
||||
URL: "my-awesome-signal.com:443",
|
||||
Connected: true,
|
||||
Error: "",
|
||||
},
|
||||
Relays: relayStateOutput{
|
||||
Total: 2,
|
||||
Available: 1,
|
||||
Details: []relayStateOutputDetail{
|
||||
{
|
||||
URI: "stun:my-awesome-stun.com:3478",
|
||||
Available: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
Available: false,
|
||||
Error: "context: deadline exceeded",
|
||||
},
|
||||
},
|
||||
},
|
||||
IP: "192.168.178.100/16",
|
||||
PubKey: "Some-Pub-Key",
|
||||
KernelInterface: true,
|
||||
FQDN: "some-localhost.awesome-domain.com",
|
||||
NSServerGroups: []nsServerGroupStateOutput{
|
||||
{
|
||||
Servers: []string{
|
||||
"8.8.8.8:53",
|
||||
},
|
||||
Domains: nil,
|
||||
Enabled: true,
|
||||
Error: "",
|
||||
},
|
||||
{
|
||||
Servers: []string{
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53",
|
||||
},
|
||||
Domains: []string{
|
||||
"example.com",
|
||||
"example.net",
|
||||
},
|
||||
Enabled: false,
|
||||
Error: "timeout",
|
||||
},
|
||||
},
|
||||
Routes: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
convertedResult := convertToStatusOutputOverview(resp)
|
||||
|
||||
assert.Equal(t, overview, convertedResult)
|
||||
}
|
||||
|
||||
func TestSortingOfPeers(t *testing.T) {
|
||||
peers := []peerStateDetailOutput{
|
||||
{
|
||||
IP: "192.168.178.104",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.102",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.101",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.105",
|
||||
},
|
||||
{
|
||||
IP: "192.168.178.103",
|
||||
},
|
||||
}
|
||||
|
||||
sortPeersByIP(peers)
|
||||
|
||||
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
||||
}
|
||||
|
||||
func TestParsingToJSON(t *testing.T) {
|
||||
jsonString, _ := parseToJSON(overview)
|
||||
|
||||
//@formatter:off
|
||||
expectedJSONString := `
|
||||
{
|
||||
"peers": {
|
||||
"total": 2,
|
||||
"connected": 2,
|
||||
"details": [
|
||||
{
|
||||
"fqdn": "peer-1.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.101",
|
||||
"publicKey": "Pubkey1",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
||||
"connectionType": "P2P",
|
||||
"iceCandidateType": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "",
|
||||
"remote": ""
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
||||
"transferReceived": 200,
|
||||
"transferSent": 100,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": [
|
||||
"10.1.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.1.0.0/24"
|
||||
]
|
||||
},
|
||||
{
|
||||
"fqdn": "peer-2.awesome-domain.com",
|
||||
"netbirdIp": "192.168.178.102",
|
||||
"publicKey": "Pubkey2",
|
||||
"status": "Connected",
|
||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
||||
"connectionType": "Relayed",
|
||||
"iceCandidateType": {
|
||||
"local": "relay",
|
||||
"remote": "prflx"
|
||||
},
|
||||
"iceCandidateEndpoint": {
|
||||
"local": "10.0.0.1:10001",
|
||||
"remote": "10.0.10.1:10002"
|
||||
},
|
||||
"relayAddress": "",
|
||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
||||
"transferReceived": 2000,
|
||||
"transferSent": 1000,
|
||||
"latency": 10000000,
|
||||
"quantumResistance": false,
|
||||
"routes": null,
|
||||
"networks": null
|
||||
}
|
||||
]
|
||||
},
|
||||
"cliVersion": "development",
|
||||
"daemonVersion": "0.14.1",
|
||||
"management": {
|
||||
"url": "my-awesome-management.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"signal": {
|
||||
"url": "my-awesome-signal.com:443",
|
||||
"connected": true,
|
||||
"error": ""
|
||||
},
|
||||
"relays": {
|
||||
"total": 2,
|
||||
"available": 1,
|
||||
"details": [
|
||||
{
|
||||
"uri": "stun:my-awesome-stun.com:3478",
|
||||
"available": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
||||
"available": false,
|
||||
"error": "context: deadline exceeded"
|
||||
}
|
||||
]
|
||||
},
|
||||
"netbirdIp": "192.168.178.100/16",
|
||||
"publicKey": "Some-Pub-Key",
|
||||
"usesKernelInterface": true,
|
||||
"fqdn": "some-localhost.awesome-domain.com",
|
||||
"quantumResistance": false,
|
||||
"quantumResistancePermissive": false,
|
||||
"routes": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"networks": [
|
||||
"10.10.0.0/24"
|
||||
],
|
||||
"dnsServers": [
|
||||
{
|
||||
"servers": [
|
||||
"8.8.8.8:53"
|
||||
],
|
||||
"domains": null,
|
||||
"enabled": true,
|
||||
"error": ""
|
||||
},
|
||||
{
|
||||
"servers": [
|
||||
"1.1.1.1:53",
|
||||
"2.2.2.2:53"
|
||||
],
|
||||
"domains": [
|
||||
"example.com",
|
||||
"example.net"
|
||||
],
|
||||
"enabled": false,
|
||||
"error": "timeout"
|
||||
}
|
||||
]
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
var expectedJSON bytes.Buffer
|
||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
||||
|
||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
||||
}
|
||||
|
||||
func TestParsingToYAML(t *testing.T) {
|
||||
yaml, _ := parseToYAML(overview)
|
||||
|
||||
expectedYAML :=
|
||||
`peers:
|
||||
total: 2
|
||||
connected: 2
|
||||
details:
|
||||
- fqdn: peer-1.awesome-domain.com
|
||||
netbirdIp: 192.168.178.101
|
||||
publicKey: Pubkey1
|
||||
status: Connected
|
||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
||||
connectionType: P2P
|
||||
iceCandidateType:
|
||||
local: ""
|
||||
remote: ""
|
||||
iceCandidateEndpoint:
|
||||
local: ""
|
||||
remote: ""
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
||||
transferReceived: 200
|
||||
transferSent: 100
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes:
|
||||
- 10.1.0.0/24
|
||||
networks:
|
||||
- 10.1.0.0/24
|
||||
- fqdn: peer-2.awesome-domain.com
|
||||
netbirdIp: 192.168.178.102
|
||||
publicKey: Pubkey2
|
||||
status: Connected
|
||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
||||
connectionType: Relayed
|
||||
iceCandidateType:
|
||||
local: relay
|
||||
remote: prflx
|
||||
iceCandidateEndpoint:
|
||||
local: 10.0.0.1:10001
|
||||
remote: 10.0.10.1:10002
|
||||
relayAddress: ""
|
||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
||||
transferReceived: 2000
|
||||
transferSent: 1000
|
||||
latency: 10ms
|
||||
quantumResistance: false
|
||||
routes: []
|
||||
networks: []
|
||||
cliVersion: development
|
||||
daemonVersion: 0.14.1
|
||||
management:
|
||||
url: my-awesome-management.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
signal:
|
||||
url: my-awesome-signal.com:443
|
||||
connected: true
|
||||
error: ""
|
||||
relays:
|
||||
total: 2
|
||||
available: 1
|
||||
details:
|
||||
- uri: stun:my-awesome-stun.com:3478
|
||||
available: true
|
||||
error: ""
|
||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
||||
available: false
|
||||
error: 'context: deadline exceeded'
|
||||
netbirdIp: 192.168.178.100/16
|
||||
publicKey: Some-Pub-Key
|
||||
usesKernelInterface: true
|
||||
fqdn: some-localhost.awesome-domain.com
|
||||
quantumResistance: false
|
||||
quantumResistancePermissive: false
|
||||
routes:
|
||||
- 10.10.0.0/24
|
||||
networks:
|
||||
- 10.10.0.0/24
|
||||
dnsServers:
|
||||
- servers:
|
||||
- 8.8.8.8:53
|
||||
domains: []
|
||||
enabled: true
|
||||
error: ""
|
||||
- servers:
|
||||
- 1.1.1.1:53
|
||||
- 2.2.2.2:53
|
||||
domains:
|
||||
- example.com
|
||||
- example.net
|
||||
enabled: false
|
||||
error: timeout
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
}
|
||||
|
||||
func TestParsingToDetail(t *testing.T) {
|
||||
// Calculate time ago based on the fixture dates
|
||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
||||
|
||||
detail := parseToFullDetailSummary(overview)
|
||||
|
||||
expectedDetail := fmt.Sprintf(
|
||||
`Peers detail:
|
||||
peer-1.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.101
|
||||
Public key: Pubkey1
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: P2P
|
||||
ICE candidate (Local/Remote): -/-
|
||||
ICE candidate endpoints (Local/Remote): -/-
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 200 B/100 B
|
||||
Quantum resistance: false
|
||||
Routes: 10.1.0.0/24
|
||||
Networks: 10.1.0.0/24
|
||||
Latency: 10ms
|
||||
|
||||
peer-2.awesome-domain.com:
|
||||
NetBird IP: 192.168.178.102
|
||||
Public key: Pubkey2
|
||||
Status: Connected
|
||||
-- detail --
|
||||
Connection type: Relayed
|
||||
ICE candidate (Local/Remote): relay/prflx
|
||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
||||
Relay server address:
|
||||
Last connection update: %s
|
||||
Last WireGuard handshake: %s
|
||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
||||
Quantum resistance: false
|
||||
Routes: -
|
||||
Networks: -
|
||||
Latency: 10ms
|
||||
|
||||
OS: %s/%s
|
||||
Daemon version: 0.14.1
|
||||
CLI version: %s
|
||||
Management: Connected to my-awesome-management.com:443
|
||||
Signal: Connected to my-awesome-signal.com:443
|
||||
Relays:
|
||||
[stun:my-awesome-stun.com:3478] is Available
|
||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
||||
Nameservers:
|
||||
[8.8.8.8:53] for [.] is Available
|
||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
|
||||
assert.Equal(t, expectedDetail, detail)
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
CLI version: development
|
||||
Management: Connected
|
||||
Signal: Connected
|
||||
Relays: 1/2 Available
|
||||
Nameservers: 1/2 Available
|
||||
FQDN: some-localhost.awesome-domain.com
|
||||
NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Routes: 10.10.0.0/24
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedString, shortVersion)
|
||||
}
|
||||
|
||||
func TestParsingOfIP(t *testing.T) {
|
||||
InterfaceIP := "192.168.178.123/16"
|
||||
|
||||
@@ -13,3 +599,31 @@ func TestParsingOfIP(t *testing.T) {
|
||||
|
||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||
}
|
||||
|
||||
func TestTimeAgo(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input time.Time
|
||||
expected string
|
||||
}{
|
||||
{"Now", now, "Now"},
|
||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
||||
{"Zero time", time.Time{}, "-"},
|
||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := timeAgo(tc.input)
|
||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ const (
|
||||
disableServerRoutesFlag = "disable-server-routes"
|
||||
disableDNSFlag = "disable-dns"
|
||||
disableFirewallFlag = "disable-firewall"
|
||||
blockLANAccessFlag = "block-lan-access"
|
||||
blockInboundFlag = "block-inbound"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -15,8 +13,6 @@ var (
|
||||
disableServerRoutes bool
|
||||
disableDNS bool
|
||||
disableFirewall bool
|
||||
blockLANAccess bool
|
||||
blockInbound bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -32,11 +28,4 @@ func init() {
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -6,17 +6,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -34,7 +30,7 @@ import (
|
||||
|
||||
func startTestingServices(t *testing.T) string {
|
||||
t.Helper()
|
||||
config := &types.Config{}
|
||||
config := &mgmt.Config{}
|
||||
_, err := util.ReadJson("../testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -69,7 +65,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
@@ -92,24 +88,14 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
||||
Example: `
|
||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||
Args: cobra.ExactArgs(3),
|
||||
RunE: tracePacket,
|
||||
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||
|
||||
for _, stage := range resp.Stages {
|
||||
if stage.ForwardingDetails != nil {
|
||||
|
||||
303
client/cmd/up.go
303
client/cmd/up.go
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -30,20 +29,9 @@ const (
|
||||
interfaceInputType
|
||||
)
|
||||
|
||||
const (
|
||||
dnsLabelsFlag = "extra-dns-labels"
|
||||
|
||||
noBrowserFlag = "no-browser"
|
||||
noBrowserDesc = "do not open the browser for SSO login"
|
||||
)
|
||||
|
||||
var (
|
||||
foregroundMode bool
|
||||
dnsLabels []string
|
||||
dnsLabelsValidated domain.List
|
||||
noBrowser bool
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
foregroundMode bool
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
Short: "install, login and start Netbird client",
|
||||
RunE: upFunc,
|
||||
@@ -55,22 +43,12 @@ func init() {
|
||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||
)
|
||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||
|
||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||
`Sets DNS labels`+
|
||||
`You can specify a comma-separated list of up to 32 labels. `+
|
||||
`An empty string "" clears the previous configuration. `+
|
||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||
`or --extra-dns-labels ""`,
|
||||
)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||
}
|
||||
|
||||
func upFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -89,11 +67,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
if hostName != "" {
|
||||
@@ -118,124 +91,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
return err
|
||||
}
|
||||
|
||||
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup config: %v", err)
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := internal.UpdateOrCreateConfig(*ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
cmd.Println("Already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get setup key: %v", err)
|
||||
}
|
||||
|
||||
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||
ic := internal.ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
@@ -243,7 +98,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*interna
|
||||
NATExternalIPs: natExternalIPs,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||
DNSLabels: dnsLabelsValidated,
|
||||
}
|
||||
|
||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||
@@ -260,7 +114,7 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*interna
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
ic.InterfaceName = &interfaceName
|
||||
}
|
||||
@@ -311,29 +165,81 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*interna
|
||||
ic.BlockLANAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
ic.BlockInbound = &blockInbound
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
config, err := internal.UpdateOrCreateConfig(ic)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
return &ic, nil
|
||||
|
||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
return connectClient.Run(nil)
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
cmd.Println("Already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
providedSetupKey, err := getSetupKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
DnsLabels: dnsLabels,
|
||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
AdminURL: adminURL,
|
||||
NatExternalIPs: natExternalIPs,
|
||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||
CustomDNSAddress: customDNSAddressConverted,
|
||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
@@ -358,7 +264,7 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
loginRequest.InterfaceName = &interfaceName
|
||||
}
|
||||
@@ -393,14 +299,45 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.BlockLanAccess = &blockLANAccess
|
||||
}
|
||||
|
||||
if cmd.Flag(blockInboundFlag).Changed {
|
||||
loginRequest.BlockInbound = &blockInbound
|
||||
var loginErr error
|
||||
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
err = WithBackOff(func() error {
|
||||
var backOffErr error
|
||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||
s.Code() == codes.PermissionDenied ||
|
||||
s.Code() == codes.NotFound ||
|
||||
s.Code() == codes.Unimplemented) {
|
||||
loginErr = backOffErr
|
||||
return nil
|
||||
}
|
||||
return backOffErr
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
if loginErr != nil {
|
||||
return fmt.Errorf("login failed: %v", loginErr)
|
||||
}
|
||||
return &loginRequest, nil
|
||||
|
||||
if loginResp.NeedsSSOLogin {
|
||||
|
||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
||||
|
||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
cmd.Println("Connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNATExternalIPs(list []string) error {
|
||||
@@ -493,24 +430,6 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
func validateDnsLabels(labels []string) (domain.List, error) {
|
||||
var (
|
||||
domains domain.List
|
||||
err error
|
||||
)
|
||||
|
||||
if len(labels) == 0 {
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
domains, err = domain.ValidateDomains(labels)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func isValidAddrPort(input string) bool {
|
||||
if input == "" {
|
||||
return true
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
// Package embed provides a way to embed the NetBird client directly
|
||||
// into Go programs without requiring a separate NetBird client installation.
|
||||
package embed
|
||||
|
||||
// Basic Usage:
|
||||
//
|
||||
// client, err := embed.New(embed.Options{
|
||||
// DeviceName: "my-service",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// Complete HTTP Server Example:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "log"
|
||||
// "net/http"
|
||||
// "os"
|
||||
// "os/signal"
|
||||
// "syscall"
|
||||
// "time"
|
||||
//
|
||||
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// // Create client with setup key and device name
|
||||
// client, err := netbird.New(netbird.Options{
|
||||
// DeviceName: "http-server",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// LogOutput: io.Discard,
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Start with timeout
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Create HTTP server
|
||||
// mux := http.NewServeMux()
|
||||
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
|
||||
// fmt.Fprintf(w, "Hello from netbird!")
|
||||
// })
|
||||
//
|
||||
// // Listen on netbird network
|
||||
// l, err := client.ListenTCP(":8080")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// server := &http.Server{Handler: mux}
|
||||
// go func() {
|
||||
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
|
||||
// log.Printf("HTTP server error: %v", err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// log.Printf("HTTP server listening on netbird network port 8080")
|
||||
//
|
||||
// // Handle shutdown
|
||||
// stop := make(chan os.Signal, 1)
|
||||
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
// <-stop
|
||||
//
|
||||
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
// log.Printf("HTTP shutdown error: %v", err)
|
||||
// }
|
||||
// if err := client.Stop(shutdownCtx); err != nil {
|
||||
// log.Printf("Netbird shutdown error: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Complete HTTP Client Example:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
// "io"
|
||||
// "log"
|
||||
// "os"
|
||||
// "time"
|
||||
//
|
||||
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// // Create client with setup key and device name
|
||||
// client, err := netbird.New(netbird.Options{
|
||||
// DeviceName: "http-client",
|
||||
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||
// LogOutput: io.Discard,
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Start with timeout
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := client.Start(ctx); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // Create HTTP client that uses netbird network
|
||||
// httpClient := client.NewHTTPClient()
|
||||
// httpClient.Timeout = 10 * time.Second
|
||||
//
|
||||
// // Make request to server in netbird network
|
||||
// target := os.Getenv("NB_TARGET")
|
||||
// resp, err := httpClient.Get(target)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//
|
||||
// // Read and print response
|
||||
// body, err := io.ReadAll(resp.Body)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// fmt.Printf("Response from server: %s\n", string(body))
|
||||
//
|
||||
// // Clean shutdown
|
||||
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// if err := client.Stop(shutdownCtx); err != nil {
|
||||
// log.Printf("Netbird shutdown error: %v", err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// The package provides several methods for network operations:
|
||||
// - Dial: Creates outbound connections
|
||||
// - ListenTCP: Creates TCP listeners
|
||||
// - ListenUDP: Creates UDP listeners
|
||||
//
|
||||
// By default, the embed package uses userspace networking mode, which doesn't
|
||||
// require root/admin privileges. For production deployments, consider setting
|
||||
// appropriate config and state paths for persistence.
|
||||
@@ -1,293 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
|
||||
// Client manages a netbird embedded client instance
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *internal.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
connect *internal.ConnectClient
|
||||
}
|
||||
|
||||
// Options configures a new Client
|
||||
type Options struct {
|
||||
// DeviceName is this peer's name in the network
|
||||
DeviceName string
|
||||
// SetupKey is used for authentication
|
||||
SetupKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
PreSharedKey string
|
||||
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||
LogOutput io.Writer
|
||||
// LogLevel sets the logging level (defaults to info if empty)
|
||||
LogLevel string
|
||||
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
|
||||
NoUserspace bool
|
||||
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
|
||||
ConfigPath string
|
||||
// StatePath is the path to the netbird state file
|
||||
StatePath string
|
||||
// DisableClientRoutes disables the client routes
|
||||
DisableClientRoutes bool
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client
|
||||
func New(opts Options) (*Client, error) {
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
|
||||
if opts.LogLevel != "" {
|
||||
level, err := logrus.ParseLevel(opts.LogLevel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse log level: %w", err)
|
||||
}
|
||||
logrus.SetLevel(level)
|
||||
}
|
||||
|
||||
if !opts.NoUserspace {
|
||||
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.StatePath != "" {
|
||||
// TODO: Disable state if path not provided
|
||||
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
|
||||
return nil, fmt.Errorf("setenv: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
t := true
|
||||
var config *internal.Config
|
||||
var err error
|
||||
input := internal.ConfigInput{
|
||||
ConfigPath: opts.ConfigPath,
|
||||
ManagementURL: opts.ManagementURL,
|
||||
PreSharedKey: &opts.PreSharedKey,
|
||||
DisableServerRoutes: &t,
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = internal.UpdateOrCreateConfig(input)
|
||||
} else {
|
||||
config, err = internal.CreateInMemoryConfig(input)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
|
||||
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
|
||||
func (c *Client) Start(startCtx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cancel != nil {
|
||||
return ErrClientAlreadyStarted
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
run := make(chan struct{}, 1)
|
||||
clientErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := client.Run(run); err != nil {
|
||||
clientErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
return startCtx.Err()
|
||||
case err := <-clientErr:
|
||||
return fmt.Errorf("startup: %w", err)
|
||||
case <-run:
|
||||
}
|
||||
|
||||
c.connect = client
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the client.
|
||||
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
|
||||
func (c *Client) Stop(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.connect == nil {
|
||||
return ErrClientNotStarted
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.connect.Stop()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancel = nil
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
c.cancel = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
return nsnet.ListenTCP(tcpAddr)
|
||||
}
|
||||
|
||||
// ListenUDP listens on the given address in the netbird network
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("split host port: %w", err)
|
||||
}
|
||||
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve: %w", err)
|
||||
}
|
||||
|
||||
return nsnet.ListenUDP(udpAddr)
|
||||
}
|
||||
|
||||
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) NewHTTPClient() *http.Client {
|
||||
transport := &http.Transport{
|
||||
DialContext: c.Dial,
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
|
||||
}
|
||||
|
||||
return nsnet, addr, nil
|
||||
}
|
||||
@@ -10,18 +10,17 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -34,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
@@ -48,7 +47,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
@@ -78,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() device.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
SetFilter(device.PacketFilter) error
|
||||
GetDevice() *device.FilteredDevice
|
||||
|
||||
@@ -30,8 +30,10 @@ type entry struct {
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
@@ -39,10 +41,12 @@ type aclManager struct {
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
||||
m := &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
routingFwChainName: routingFwChainName,
|
||||
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
@@ -75,7 +79,6 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
@@ -311,12 +314,9 @@ func (m *aclManager) seedInitialEntries() {
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
// Inbound is handled by our ACLs, the rest is dropped.
|
||||
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ type Manager struct {
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
}
|
||||
@@ -96,36 +96,36 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -147,10 +147,6 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -170,7 +166,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -200,13 +196,13 @@ func (m *Manager) AllowNetbird() error {
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
firewall.ProtocolALL,
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
@@ -222,44 +218,6 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -10,17 +10,20 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -28,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -38,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -59,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Close(nil)
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -67,12 +70,12 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -92,19 +95,19 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -114,10 +117,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -130,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Close(nil)
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -138,11 +144,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -160,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
err = manager.Close(nil)
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
})
|
||||
}
|
||||
@@ -178,10 +184,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("10.20.0.1"),
|
||||
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("10.20.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("10.20.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -195,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err := manager.Close(nil)
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "clear the manager state")
|
||||
|
||||
time.Sleep(time.Second)
|
||||
@@ -203,11 +212,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -15,8 +15,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -24,51 +23,35 @@ import (
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
|
||||
tableFilter = "filter"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTFWD = "NETBIRD-RT-FWD"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
jumpPre = "jump-pre"
|
||||
jumpNat = "jump-nat"
|
||||
matchSet = "--match-set"
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
chain string
|
||||
table string
|
||||
rule []string
|
||||
}
|
||||
|
||||
type routeFilteringRuleParams struct {
|
||||
Source firewall.Network
|
||||
Destination firewall.Network
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
SetName string
|
||||
}
|
||||
|
||||
type routeRules map[string][]string
|
||||
|
||||
// the ipset library currently does not support comments, so we use the name only (string)
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
type router struct {
|
||||
@@ -79,7 +62,6 @@ type router struct {
|
||||
legacyManagement bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
@@ -87,7 +69,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -117,56 +98,50 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
var source firewall.Network
|
||||
var setName string
|
||||
if len(sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(sources)
|
||||
} else if len(sources) > 0 {
|
||||
source.Prefix = sources[0]
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
params := routeFilteringRuleParams{
|
||||
Source: source,
|
||||
Sources: sources,
|
||||
Destination: destination,
|
||||
Proto: proto,
|
||||
SPort: sPort,
|
||||
DPort: dPort,
|
||||
Action: action,
|
||||
SetName: setName,
|
||||
}
|
||||
|
||||
rule, err := r.genRouteRuleSpec(params, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate route rule spec: %w", err)
|
||||
}
|
||||
|
||||
rule := genRouteFilteringRuleSpec(params)
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
// after the established rule
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
|
||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -181,16 +156,20 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
ruleKey := rule.ID()
|
||||
ruleKey := rule.GetRuleID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
setName := r.findSetNameInRule(rule)
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("delete route rule: %v", err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
if setName != "" {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("route rule %s not found", ruleKey)
|
||||
@@ -201,26 +180,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) decrementSetCounter(rule []string) error {
|
||||
sets := r.findSets(rule)
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) findSets(rule []string) []string {
|
||||
var sets []string
|
||||
func (r *router) findSetNameInRule(rule []string) string {
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
sets = append(sets, rule[i+3])
|
||||
return rule[i+3]
|
||||
}
|
||||
}
|
||||
return sets
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
@@ -241,8 +207,6 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
if err := ipset.Destroy(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -274,14 +238,12 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -302,7 +264,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
@@ -315,14 +277,12 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -345,7 +305,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
@@ -362,16 +322,12 @@ func (r *router) Reset() error {
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
r.rules = make(map[string][]string)
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.rules = make(map[string][]string)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -387,11 +343,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWDIN, tableFilter},
|
||||
{chainRTFWDOUT, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTPRE, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
@@ -411,22 +365,16 @@ func (r *router) createContainers() error {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWDIN, tableFilter},
|
||||
{chainRTFWDOUT, tableFilter},
|
||||
{chainRTFWD, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
} {
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
|
||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -441,57 +389,6 @@ func (r *router) createContainers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *router) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
@@ -518,6 +415,27 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) createAndSetupChain(chain string) error {
|
||||
table := r.getTableForChain(chain)
|
||||
|
||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) getTableForChain(chain string) string {
|
||||
switch chain {
|
||||
case chainRTNAT:
|
||||
return tableNat
|
||||
case chainRTPRE:
|
||||
return tableMangle
|
||||
default:
|
||||
return tableFilter
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
@@ -533,46 +451,31 @@ func (r *router) insertEstablishedRule(chain string) error {
|
||||
}
|
||||
|
||||
func (r *router) addJumpRules() error {
|
||||
// Jump to nat chain
|
||||
// Jump to NAT chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||
return fmt.Errorf("add nat jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpNatPost] = natRule
|
||||
r.rules[jumpNat] = natRule
|
||||
|
||||
// Jump to mangle prerouting chain
|
||||
// Jump to prerouting chain
|
||||
preRule := []string{"-j", chainRTPRE}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
|
||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpManglePre] = preRule
|
||||
|
||||
// Jump to nat prerouting chain
|
||||
rdrRule := []string{"-j", chainRTRDR}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
|
||||
return fmt.Errorf("add nat prerouting jump rule: %v", err)
|
||||
}
|
||||
r.rules[jumpNatPre] = rdrRule
|
||||
r.rules[jumpPre] = preRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
var table, chain string
|
||||
switch ruleKey {
|
||||
case jumpNatPost:
|
||||
table = tableNat
|
||||
chain = chainPOSTROUTING
|
||||
case jumpManglePre:
|
||||
table := tableNat
|
||||
chain := chainPOSTROUTING
|
||||
if ruleKey == jumpPre {
|
||||
table = tableMangle
|
||||
chain = chainPREROUTING
|
||||
case jumpNatPre:
|
||||
table = tableNat
|
||||
chain = chainPREROUTING
|
||||
default:
|
||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||
@@ -607,32 +510,16 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
)
|
||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -s: %w", err)
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule,
|
||||
"-s", pair.Source.String(),
|
||||
"-d", pair.Destination.String(),
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||
// TODO: rollback ipset counter
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -644,15 +531,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -682,152 +564,17 @@ func (r *router) updateState() {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
toDestination := rule.TranslatedAddress.String()
|
||||
switch {
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
// no translated port, use original port
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
// need the "/originalport" suffix to avoid dnat port randomization
|
||||
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
|
||||
proto := strings.ToLower(string(rule.Protocol))
|
||||
|
||||
rules := make(map[string]ruleInfo, 3)
|
||||
|
||||
// DNAT rule
|
||||
dnatRule := []string{
|
||||
"!", "-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "DNAT",
|
||||
"--to-destination", toDestination,
|
||||
}
|
||||
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||
rules[ruleKey+dnatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
// SNAT rule
|
||||
snatRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleKey+snatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTNAT,
|
||||
rule: snatRule,
|
||||
}
|
||||
|
||||
// Forward filtering rule, if fwd policy is DROP
|
||||
forwardRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "ACCEPT",
|
||||
}
|
||||
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleKey+fwdSuffix] = ruleInfo{
|
||||
table: tableFilter,
|
||||
chain: chainRTFWDOUT,
|
||||
rule: forwardRule,
|
||||
}
|
||||
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||
log.Errorf("rollback failed: %v", rollbackErr)
|
||||
}
|
||||
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||
}
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
||||
var merr *multierror.Error
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||
// On rollback error, add to rules map for next cleanup
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
r.updateState()
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
var merr *multierror.Error
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
}
|
||||
|
||||
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
|
||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
var rule []string
|
||||
|
||||
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -s: %w", err)
|
||||
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", params.Destination, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||
if params.SetName != "" {
|
||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
||||
} else if len(params.Sources) > 0 {
|
||||
source := params.Sources[0]
|
||||
rule = append(rule, "-s", source.String())
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
@@ -837,47 +584,7 @@ func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []net
|
||||
|
||||
rule = append(rule, "-j", actionToStr(params.Action))
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
direction := "src"
|
||||
if flag == "-d" {
|
||||
direction = "dst"
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
// TODO: Implement IPv6 support
|
||||
if prefix.Addr().Is6() {
|
||||
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||
continue
|
||||
}
|
||||
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return rule
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
|
||||
@@ -39,16 +39,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Now 5 rules:
|
||||
// 1. established rule forward in
|
||||
// 2. estbalished rule forward out
|
||||
// 3. jump rule to POST nat chain
|
||||
// 4. jump rule to PRE mangle chain
|
||||
// 5. jump rule to PRE nat chain
|
||||
// 6. static outbound masquerade rule
|
||||
// 7. static return masquerade rule
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
// 1. established rule in forward chain
|
||||
// 2. jump rule to NAT chain
|
||||
// 3. jump rule to PRE chain
|
||||
// 4. static outbound masquerade rule
|
||||
// 5. static return masquerade rule
|
||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
@@ -60,8 +56,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
ID: "abc",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
@@ -332,44 +328,38 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
var source firewall.Network
|
||||
if len(tt.sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||
} else if len(tt.sources) > 0 {
|
||||
source.Prefix = tt.sources[0]
|
||||
}
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Source: source,
|
||||
Destination: firewall.Network{Prefix: tt.destination},
|
||||
Sources: tt.sources,
|
||||
Destination: tt.destination,
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
SetName: "",
|
||||
}
|
||||
|
||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
@@ -384,62 +374,3 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSetNameInRule(t *testing.T) {
|
||||
r := &router{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Basic rule with two sets",
|
||||
rule: []string{
|
||||
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
|
||||
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
|
||||
},
|
||||
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
|
||||
},
|
||||
{
|
||||
name: "No sets",
|
||||
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Multiple sets with different positions",
|
||||
rule: []string{
|
||||
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
|
||||
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
|
||||
},
|
||||
expected: []string{"set1", "set-abc123"},
|
||||
},
|
||||
{
|
||||
name: "Boundary case - sequence appears at end",
|
||||
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
|
||||
expected: []string{"final-set"},
|
||||
},
|
||||
{
|
||||
name: "Incomplete pattern - missing set name",
|
||||
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
|
||||
expected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := r.findSets(tc.rule)
|
||||
|
||||
if len(result) != len(tc.expected) {
|
||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, set := range result {
|
||||
if set != tc.expected[i] {
|
||||
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,6 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
@@ -4,20 +4,21 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
@@ -61,7 +62,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
}
|
||||
|
||||
if err := ipt.Close(nil); err != nil {
|
||||
if err := ipt.Reset(nil); err != nil {
|
||||
return fmt.Errorf("reset iptables manager: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -23,8 +26,8 @@ const (
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// ID returns the rule id
|
||||
ID() string
|
||||
// GetRuleID returns the rule id
|
||||
GetRuleID() string
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
@@ -40,18 +43,6 @@ const (
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
// String returns the string representation of the action
|
||||
func (a Action) String() string {
|
||||
switch a {
|
||||
case ActionAccept:
|
||||
return "accept"
|
||||
case ActionDrop:
|
||||
return "drop"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
@@ -59,33 +50,6 @@ const (
|
||||
ActionDrop
|
||||
)
|
||||
|
||||
// Network is a rule destination, either a set or a prefix
|
||||
type Network struct {
|
||||
Set Set
|
||||
Prefix netip.Prefix
|
||||
}
|
||||
|
||||
// String returns the string representation of the destination
|
||||
func (d Network) String() string {
|
||||
if d.Prefix.IsValid() {
|
||||
return d.Prefix.String()
|
||||
}
|
||||
if d.IsSet() {
|
||||
return d.Set.HashedName()
|
||||
}
|
||||
return "<invalid network>"
|
||||
}
|
||||
|
||||
// IsSet returns true if the destination is a set
|
||||
func (d Network) IsSet() bool {
|
||||
return d.Set != Set{}
|
||||
}
|
||||
|
||||
// IsPrefix returns true if the destination is a valid prefix
|
||||
func (d Network) IsPrefix() bool {
|
||||
return d.Prefix.IsValid()
|
||||
}
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
@@ -101,13 +65,13 @@ type Manager interface {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -116,16 +80,7 @@ type Manager interface {
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort, dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
@@ -139,26 +94,13 @@ type Manager interface {
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Close closes the firewall manager
|
||||
Close(stateManager *statemanager.Manager) error
|
||||
// Reset firewall to the default state
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
SetLogLevel(log.Level)
|
||||
|
||||
EnableRouting() error
|
||||
|
||||
DisableRouting() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
@@ -193,6 +135,22 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
|
||||
@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.NewPrefixSet(prefixes1)
|
||||
result2 := manager.NewPrefixSet(prefixes2)
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.NewPrefixSet(prefixes)
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.NewPrefixSet([]netip.Prefix{})
|
||||
result2 := manager.NewPrefixSet([]netip.Prefix{})
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.NewPrefixSet(prefixes1)
|
||||
result2 := manager.NewPrefixSet(prefixes2)
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// ForwardRule todo figure out better place to this to avoid circular imports
|
||||
type ForwardRule struct {
|
||||
Protocol Protocol
|
||||
DestinationPort Port
|
||||
TranslatedAddress netip.Addr
|
||||
TranslatedPort Port
|
||||
}
|
||||
|
||||
func (r ForwardRule) ID() string {
|
||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||
r.Protocol,
|
||||
r.DestinationPort.String(),
|
||||
r.TranslatedAddress.String(),
|
||||
r.TranslatedPort.String())
|
||||
return id
|
||||
}
|
||||
|
||||
func (r ForwardRule) String() string {
|
||||
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
|
||||
}
|
||||
@@ -1,12 +1,30 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
|
||||
// ProtocolALL cover all supported protocols
|
||||
ProtocolALL Protocol = "all"
|
||||
|
||||
// ProtocolUnknown unknown protocol
|
||||
ProtocolUnknown Protocol = "unknown"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||
type Port struct {
|
||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||
IsRange bool
|
||||
@@ -15,25 +33,6 @@ type Port struct {
|
||||
Values []uint16
|
||||
}
|
||||
|
||||
func NewPort(ports ...int) (*Port, error) {
|
||||
if len(ports) == 0 {
|
||||
return nil, fmt.Errorf("no port provided")
|
||||
}
|
||||
|
||||
ports16 := make([]uint16, len(ports))
|
||||
for i, port := range ports {
|
||||
if port < 1 || port > 65535 {
|
||||
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
|
||||
}
|
||||
ports16[i] = uint16(port)
|
||||
}
|
||||
|
||||
return &Port{
|
||||
IsRange: len(ports) > 1,
|
||||
Values: ports16,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// String interface implementation
|
||||
func (p *Port) String() string {
|
||||
var ports string
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package manager
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
|
||||
// ProtocolALL cover all supported protocols
|
||||
ProtocolALL Protocol = "all"
|
||||
)
|
||||
@@ -1,13 +1,15 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouterPair struct {
|
||||
ID route.ID
|
||||
Source Network
|
||||
Destination Network
|
||||
Source netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Masquerade bool
|
||||
Inverse bool
|
||||
}
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
type Set struct {
|
||||
hash [4]byte
|
||||
comment string
|
||||
}
|
||||
|
||||
// String returns the string representation of the set: hashed name and comment
|
||||
func (h Set) String() string {
|
||||
if h.comment == "" {
|
||||
return h.HashedName()
|
||||
}
|
||||
return h.HashedName() + ": " + h.comment
|
||||
}
|
||||
|
||||
// HashedName returns the string representation of the hash
|
||||
func (h Set) HashedName() string {
|
||||
return fmt.Sprintf(
|
||||
"nb-%s",
|
||||
hex.EncodeToString(h.hash[:]),
|
||||
)
|
||||
}
|
||||
|
||||
// Comment returns the comment of the set
|
||||
func (h Set) Comment() string {
|
||||
return h.comment
|
||||
}
|
||||
|
||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(prefixes)
|
||||
|
||||
hash := sha256.New()
|
||||
for _, src := range prefixes {
|
||||
bytes, err := src.MarshalBinary()
|
||||
if err != nil {
|
||||
log.Warnf("failed to marshal prefix %s: %v", src, err)
|
||||
}
|
||||
hash.Write(bytes)
|
||||
}
|
||||
var set Set
|
||||
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||
|
||||
return set
|
||||
}
|
||||
|
||||
// NewDomainSet generates a unique name for an ipset based on the given domains.
|
||||
func NewDomainSet(domains domain.List) Set {
|
||||
slices.Sort(domains)
|
||||
|
||||
hash := sha256.New()
|
||||
for _, d := range domains {
|
||||
hash.Write([]byte(d.PunycodeString()))
|
||||
}
|
||||
set := Set{
|
||||
comment: domains.SafeString(),
|
||||
}
|
||||
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||
|
||||
return set
|
||||
}
|
||||
@@ -25,10 +25,9 @@ const (
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNamePrerouting = "netbird-rt-prerouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
@@ -85,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
@@ -103,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -128,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
delete(m.rules, r.GetRuleID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
@@ -142,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
delete(m.rules, r.GetRuleID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
@@ -177,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.rules, r.ID())
|
||||
delete(m.rules, r.GetRuleID())
|
||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||
|
||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||
@@ -257,6 +256,7 @@ func (m *AclManager) addIOFiltering(
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
@@ -338,7 +338,7 @@ func (m *AclManager) addIOFiltering(
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(ruleId)
|
||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
||||
|
||||
chain := m.chainInputRules
|
||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||
@@ -463,15 +463,13 @@ func (m *AclManager) createDefaultChains() (err error) {
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNamePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
})
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ const (
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// We only need to record minimal interface state for potential recreation.
|
||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||
// cleanup using Close() without needing to store specific rules.
|
||||
// cleanup using Reset() without needing to store specific rules.
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -129,25 +129,25 @@ func (m *Manager) AddPeerFiltering(
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||
}
|
||||
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||
if !destination.Addr().Is4() {
|
||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
@@ -170,10 +170,6 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -245,8 +241,8 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Close closes the firewall manager
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -327,20 +323,6 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
// not supported
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
@@ -352,30 +334,6 @@ func (m *Manager) Flush() error {
|
||||
return m.aclManager.Flush()
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package nftables
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
@@ -15,17 +16,20 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
var ifaceMock = &iFaceMock{
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -33,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMock struct {
|
||||
NameFunc func() string
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() iface.WGAddress
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Name() string {
|
||||
@@ -43,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
||||
panic("NameFunc is not set")
|
||||
}
|
||||
|
||||
func (i *iFaceMock) Address() wgaddr.Address {
|
||||
func (i *iFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc != nil {
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -61,16 +65,16 @@ func TestNftablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Close(nil)
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -105,6 +109,8 @@ func TestNftablesManager(t *testing.T) {
|
||||
}
|
||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs2 := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -126,7 +132,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip.AsSlice(),
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
@@ -156,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
|
||||
err = manager.Close(nil)
|
||||
err = manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
@@ -165,10 +171,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
NameFunc: func() string {
|
||||
return "lo"
|
||||
},
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.96.0.1"),
|
||||
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.96.0.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.96.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -182,17 +191,17 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Close(nil); err != nil {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := netip.MustParseAddr("10.20.0.100")
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -265,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Close(nil)
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
@@ -273,14 +282,13 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []uint16{443}},
|
||||
@@ -289,8 +297,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNameManglePrerouting {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
// Verify the rule was removed
|
||||
found = false
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.ID() {
|
||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
printNftSets()
|
||||
@@ -595,20 +595,16 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||
payloadFound = true
|
||||
}
|
||||
case *expr.Range:
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
fromPort := binary.BigEndian.Uint16(ex.FromData)
|
||||
toPort := binary.BigEndian.Uint16(ex.ToData)
|
||||
if fromPort == port.Values[0] && toPort == port.Values[1] {
|
||||
case *expr.Cmp:
|
||||
if port.IsRange {
|
||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
||||
portMatchFound = true
|
||||
}
|
||||
}
|
||||
case *expr.Cmp:
|
||||
if !port.IsRange {
|
||||
} else {
|
||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||
for _, p := range port.Values {
|
||||
if p == portValue {
|
||||
if uint16(p) == portValue {
|
||||
portMatchFound = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -16,6 +16,6 @@ type Rule struct {
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
@@ -3,20 +3,21 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
NameStr string `json:"name"`
|
||||
WGAddress iface.WGAddress `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
return i.NameStr
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Address() wgaddr.Address {
|
||||
func (i *InterfaceState) Address() device.WGAddress {
|
||||
return i.WGAddress
|
||||
}
|
||||
|
||||
@@ -38,7 +39,7 @@ func (s *ShutdownState) Cleanup() error {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
if err := nft.Close(nil); err != nil {
|
||||
if err := nft.Reset(nil); err != nil {
|
||||
return fmt.Errorf("reset nftables manager: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@ var (
|
||||
Name: "Insert Forwarding IPV4 Rule",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Masquerade: false,
|
||||
},
|
||||
},
|
||||
@@ -24,8 +24,8 @@ var (
|
||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
@@ -40,8 +40,8 @@ var (
|
||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||
InputPair: firewall.RouterPair{
|
||||
ID: "zxa",
|
||||
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
||||
Masquerade: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -4,36 +4,39 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
@@ -45,7 +48,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
return m.nativeFirewall.Reset(stateManager)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,13 +3,13 @@ package uspfilter
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -21,28 +21,31 @@ const (
|
||||
firewallRuleName = "Netbird"
|
||||
)
|
||||
|
||||
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||
func (m *Manager) Close(*statemanager.Manager) error {
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
if m.udpTracker != nil {
|
||||
m.udpTracker.Close()
|
||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.icmpTracker != nil {
|
||||
m.icmpTracker.Close()
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if m.tcpTracker != nil {
|
||||
m.tcpTracker.Close()
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
||||
}
|
||||
|
||||
if fwder := m.forwarder.Load(); fwder != nil {
|
||||
fwder.Stop()
|
||||
if m.forwarder != nil {
|
||||
m.forwarder.Stop()
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
|
||||
@@ -3,14 +3,14 @@ package common
|
||||
import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
Address() iface.WGAddress
|
||||
GetWGDevice() *wgdevice.Device
|
||||
GetDevice() *device.FilteredDevice
|
||||
}
|
||||
|
||||
@@ -1,27 +1,20 @@
|
||||
// common.go
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
FlowId uuid.UUID
|
||||
Direction nftypes.Direction
|
||||
SourceIP netip.Addr
|
||||
DestIP netip.Addr
|
||||
lastSeen atomic.Int64
|
||||
PacketsTx atomic.Uint64
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
@@ -31,17 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
||||
b.lastSeen.Store(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// UpdateCounters safely updates the packet and byte counters
|
||||
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||
if direction == nftypes.Egress {
|
||||
b.PacketsTx.Add(1)
|
||||
b.BytesTx.Add(uint64(bytes))
|
||||
} else {
|
||||
b.PacketsRx.Add(1)
|
||||
b.BytesRx.Add(uint64(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
// GetLastSeen safely gets the last seen timestamp
|
||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||
return time.Unix(0, b.lastSeen.Load())
|
||||
@@ -53,14 +35,92 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
||||
return time.Since(lastSeen) > timeout
|
||||
}
|
||||
|
||||
// IPAddr is a fixed-size IP address to avoid allocations
|
||||
type IPAddr [16]byte
|
||||
|
||||
// MakeIPAddr creates an IPAddr from net.IP
|
||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
||||
// Optimization: check for v4 first as it's more common
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
copy(addr[12:], ip4)
|
||||
} else {
|
||||
copy(addr[:], ip.To16())
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// ConnKey uniquely identifies a connection
|
||||
type ConnKey struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcIP IPAddr
|
||||
DstIP IPAddr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
func (c ConnKey) String() string {
|
||||
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||
// makeConnKey creates a connection key
|
||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
||||
return ConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateIPs checks if IPs match without allocation
|
||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
||||
if ip4 := pktIP.To4(); ip4 != nil {
|
||||
// Compare IPv4 addresses (last 4 bytes)
|
||||
for i := 0; i < 4; i++ {
|
||||
if connIP[12+i] != ip4[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
// Compare full IPv6 addresses
|
||||
ip6 := pktIP.To16()
|
||||
for i := 0; i < 16; i++ {
|
||||
if connIP[i] != ip6[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
||||
type PreallocatedIPs struct {
|
||||
sync.Pool
|
||||
}
|
||||
|
||||
// NewPreallocatedIPs creates a new IP pool
|
||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
||||
return &PreallocatedIPs{
|
||||
Pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
ip := make(net.IP, 16)
|
||||
return &ip
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an IP from the pool
|
||||
func (p *PreallocatedIPs) Get() net.IP {
|
||||
return *p.Pool.Get().(*net.IP)
|
||||
}
|
||||
|
||||
// Put returns an IP to the pool
|
||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
||||
p.Pool.Put(&ip)
|
||||
}
|
||||
|
||||
// copyIP copies an IP address efficiently
|
||||
func copyIP(dst, src net.IP) {
|
||||
if len(src) == 16 {
|
||||
copy(dst, src)
|
||||
} else {
|
||||
// Handle IPv4
|
||||
copy(dst[12:], src.To4())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,66 +1,94 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func BenchmarkIPOperations(b *testing.B) {
|
||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = MakeIPAddr(ip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ValidateIPs", func(b *testing.B) {
|
||||
ip1 := net.ParseIP("192.168.1.1")
|
||||
ip2 := net.ParseIP("192.168.1.1")
|
||||
addr := MakeIPAddr(ip1)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ValidateIPs(addr, ip2)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IPPool", func(b *testing.B) {
|
||||
pool := NewPreallocatedIPs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ip := pool.Get()
|
||||
pool.Put(ip)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// Memory pressure tests
|
||||
func BenchmarkMemoryPressure(b *testing.B) {
|
||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Generate different IPs
|
||||
srcIPs := make([]netip.Addr, 100)
|
||||
dstIPs := make([]netip.Addr, 100)
|
||||
srcIPs := make([]net.IP, 100)
|
||||
dstIPs := make([]net.IP, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
srcIdx := i % len(srcIPs)
|
||||
dstIdx := (i + 1) % len(dstIPs)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
||||
|
||||
// Simulate some valid inbound packets
|
||||
if i%3 == 0 {
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,28 +15,22 @@ const (
|
||||
DefaultICMPTimeout = 30 * time.Second
|
||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||
ICMPCleanupInterval = 15 * time.Second
|
||||
|
||||
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||
MaxICMPPayloadLength = 28
|
||||
)
|
||||
|
||||
// ICMPConnKey uniquely identifies an ICMP connection
|
||||
type ICMPConnKey struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
ID uint16
|
||||
}
|
||||
|
||||
func (i ICMPConnKey) String() string {
|
||||
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||
// Supports both IPv4 and IPv6
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
Sequence uint16 // ICMP sequence number
|
||||
ID uint16 // ICMP identifier
|
||||
}
|
||||
|
||||
// ICMPConnTrack represents an ICMP connection state
|
||||
type ICMPConnTrack struct {
|
||||
BaseConnTrack
|
||||
ICMPType uint8
|
||||
ICMPCode uint8
|
||||
Sequence uint16
|
||||
ID uint16
|
||||
}
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
@@ -50,302 +39,131 @@ type ICMPTracker struct {
|
||||
connections map[ICMPConnKey]*ICMPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
||||
type ICMPInfo struct {
|
||||
TypeCode layers.ICMPv4TypeCode
|
||||
PayloadData [MaxICMPPayloadLength]byte
|
||||
// actual length of valid data
|
||||
PayloadLen int
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||
func (info ICMPInfo) String() string {
|
||||
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return info.TypeCode.String()
|
||||
}
|
||||
|
||||
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||
func (info ICMPInfo) isErrorMessage() bool {
|
||||
typ := info.TypeCode.Type()
|
||||
return typ == 3 || // Destination Unreachable
|
||||
typ == 5 || // Redirect
|
||||
typ == 11 || // Time Exceeded
|
||||
typ == 12 // Parameter Problem
|
||||
}
|
||||
|
||||
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||
func (info ICMPInfo) parseOriginalPacket() string {
|
||||
if info.PayloadLen < MaxICMPPayloadLength {
|
||||
return ""
|
||||
}
|
||||
|
||||
// TODO: handle IPv6
|
||||
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
protocol := info.PayloadData[9]
|
||||
srcIP := net.IP(info.PayloadData[12:16])
|
||||
dstIP := net.IP(info.PayloadData[16:20])
|
||||
|
||||
transportData := info.PayloadData[20:]
|
||||
|
||||
switch nftypes.Protocol(protocol) {
|
||||
case nftypes.TCP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.UDP:
|
||||
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
case nftypes.ICMP:
|
||||
icmpType := transportData[0]
|
||||
icmpCode := transportData[1]
|
||||
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
||||
|
||||
default:
|
||||
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
||||
}
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
}
|
||||
|
||||
// NewICMPTracker creates a new ICMP connection tracker
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultICMPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &ICMPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
flowLogger: flowLogger,
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||
key := ICMPConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
ID: id,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
return key, true
|
||||
}
|
||||
|
||||
return key, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound ICMP connection
|
||||
func (t *ICMPTracker) TrackOutbound(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackInbound(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
ruleId []byte,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||
func (t *ICMPTracker) track(
|
||||
srcIP netip.Addr,
|
||||
dstIP netip.Addr,
|
||||
id uint16,
|
||||
typecode layers.ICMPv4TypeCode,
|
||||
direction nftypes.Direction,
|
||||
ruleId []byte,
|
||||
payload []byte,
|
||||
size int,
|
||||
) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
typ, code := typecode.Type(), typecode.Code()
|
||||
icmpInfo := ICMPInfo{
|
||||
TypeCode: typecode,
|
||||
}
|
||||
if len(payload) > 0 {
|
||||
icmpInfo.PayloadLen = len(payload)
|
||||
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
||||
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
||||
}
|
||||
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
||||
}
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
|
||||
conn := &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
// TrackOutbound records an outbound ICMP Echo Request
|
||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &ICMPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
},
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New ICMP connection %v", key)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := ICMPConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
ID: id,
|
||||
}
|
||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.ID == id &&
|
||||
conn.Sequence == seq
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.tickerCancel()
|
||||
|
||||
func (t *ICMPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-ctx.Done():
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *ICMPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
ICMPType: conn.ICMPType,
|
||||
ICMPCode: conn.ICMPCode,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
Type: nftypes.TypeStart,
|
||||
RuleID: ruleID,
|
||||
Direction: direction,
|
||||
Protocol: nftypes.ICMP,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
ICMPType: typ,
|
||||
ICMPCode: code,
|
||||
// makeICMPKey creates an ICMP connection key
|
||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
||||
return ICMPConnKey{
|
||||
SrcIP: MakeIPAddr(srcIP),
|
||||
DstIP: MakeIPAddr(dstIP),
|
||||
ID: id,
|
||||
Sequence: seq,
|
||||
}
|
||||
if direction == nftypes.Ingress {
|
||||
fields.RxPackets = 1
|
||||
fields.RxBytes = uint64(size)
|
||||
} else {
|
||||
fields.TxPackets = 1
|
||||
fields.TxBytes = uint64(size)
|
||||
}
|
||||
t.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -1,39 +1,39 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkICMPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,16 +3,12 @@ package conntrack
|
||||
// TODO: Send RST packets for invalid/timed-out connections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,11 +19,11 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
TCPFin uint8 = 0x01
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPFin uint8 = 0x01
|
||||
TCPRst uint8 = 0x04
|
||||
TCPPush uint8 = 0x08
|
||||
TCPAck uint8 = 0x10
|
||||
TCPUrg uint8 = 0x20
|
||||
)
|
||||
|
||||
@@ -41,36 +37,7 @@ const (
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int32
|
||||
|
||||
func (s TCPState) String() string {
|
||||
switch s {
|
||||
case TCPStateNew:
|
||||
return "New"
|
||||
case TCPStateSynSent:
|
||||
return "SYN Sent"
|
||||
case TCPStateSynReceived:
|
||||
return "SYN Received"
|
||||
case TCPStateEstablished:
|
||||
return "Established"
|
||||
case TCPStateFinWait1:
|
||||
return "FIN Wait 1"
|
||||
case TCPStateFinWait2:
|
||||
return "FIN Wait 2"
|
||||
case TCPStateClosing:
|
||||
return "Closing"
|
||||
case TCPStateTimeWait:
|
||||
return "Time Wait"
|
||||
case TCPStateCloseWait:
|
||||
return "Close Wait"
|
||||
case TCPStateLastAck:
|
||||
return "Last ACK"
|
||||
case TCPStateClosed:
|
||||
return "Closed"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
type TCPState int
|
||||
|
||||
const (
|
||||
TCPStateNew TCPState = iota
|
||||
@@ -86,38 +53,30 @@ const (
|
||||
TCPStateClosed
|
||||
)
|
||||
|
||||
// TCPConnKey uniquely identifies a TCP connection
|
||||
type TCPConnKey struct {
|
||||
SrcIP [16]byte
|
||||
DstIP [16]byte
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
state atomic.Int32
|
||||
tombstone atomic.Bool
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// GetState safely retrieves the current state
|
||||
func (t *TCPConnTrack) GetState() TCPState {
|
||||
return TCPState(t.state.Load())
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
}
|
||||
|
||||
// SetState safely updates the current state
|
||||
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||
t.state.Store(int32(state))
|
||||
}
|
||||
|
||||
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||
}
|
||||
|
||||
// IsTombstone safely checks if the connection is marked for deletion
|
||||
func (t *TCPConnTrack) IsTombstone() bool {
|
||||
return t.tombstone.Load()
|
||||
}
|
||||
|
||||
// SetTombstone safely marks the connection for deletion
|
||||
func (t *TCPConnTrack) SetTombstone() {
|
||||
t.tombstone.Store(true)
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
@@ -126,234 +85,192 @@ type TCPTracker struct {
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
done chan struct{}
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
ipPool *PreallocatedIPs
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||
waitTimeout := TimeWaitTimeout
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTCPTimeout
|
||||
} else {
|
||||
waitTimeout = timeout / 45
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
return key, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
|
||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
||||
// Create key before lock
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
// Use preallocated IPs
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
State: TCPStateNew,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
// Lock individual connection for state update
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, true)
|
||||
conn.Unlock()
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.IsTombstone() {
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
// Handle RST packets
|
||||
if flags&TCPRst != 0 {
|
||||
conn.Lock()
|
||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
return true
|
||||
}
|
||||
conn.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||
return true
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
|
||||
return isEstablished || isValidState
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
currentState := conn.GetState()
|
||||
|
||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
return
|
||||
}
|
||||
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
switch conn.State {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
conn.State = TCPStateSynSent
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
if isOutbound {
|
||||
conn.State = TCPStateSynReceived
|
||||
} else {
|
||||
// Simultaneous open
|
||||
newState = TCPStateSynReceived
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
}
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
if isOutbound {
|
||||
conn.State = TCPStateFinWait1
|
||||
} else {
|
||||
newState = TCPStateCloseWait
|
||||
conn.State = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
// Simultaneous close - both sides sent FIN
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
conn.State = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateLastAck
|
||||
conn.State = TCPStateLastAck
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateClosed
|
||||
conn.State = TCPStateClosed
|
||||
|
||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
case TCPStateTimeWait:
|
||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
||||
// This is handled by the cleanup routine
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,22 +279,18 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||
case TCPStateSynSent:
|
||||
// TODO: support simultaneous open
|
||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||
case TCPStateSynReceived:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPRst != 0 {
|
||||
return true
|
||||
}
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateFinWait1:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
@@ -394,20 +307,20 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
func (t *TCPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-ctx.Done():
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -418,43 +331,39 @@ func (t *TCPTracker) cleanup() {
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.IsTombstone() {
|
||||
// Clean up tombstoned connections without sending an event
|
||||
delete(t.connections, key)
|
||||
continue
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
currentState := conn.GetState()
|
||||
switch currentState {
|
||||
case TCPStateTimeWait:
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
timeout = TimeWaitTimeout
|
||||
case conn.IsEstablished():
|
||||
timeout = t.timeout
|
||||
default:
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
lastSeen := conn.GetLastSeen()
|
||||
if time.Since(lastSeen) > timeout {
|
||||
// Return IPs to pool
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if currentState != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
// Clean up all remaining IPs
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
@@ -472,21 +381,3 @@ func isValidFlagCombination(flags uint8) bool {
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.TCP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,20 +1,19 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPStateMachine(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -59,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||
})
|
||||
}
|
||||
@@ -77,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// Send initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
|
||||
// Receive SYN-ACK
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
// Send ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
|
||||
// Test data transfer
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
require.True(t, valid, "Data should be allowed after handshake")
|
||||
},
|
||||
},
|
||||
@@ -100,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
require.True(t, valid, "ACK for FIN should be allowed")
|
||||
|
||||
// Receive FIN from other side
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
require.True(t, valid, "FIN should be allowed")
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -123,8 +122,11 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -136,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Both sides send FIN+ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||
|
||||
// Both sides send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
||||
require.True(t, valid, "Final ACKs should be allowed")
|
||||
},
|
||||
},
|
||||
@@ -152,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
@@ -160,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRSTHandling(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcIP := net.ParseIP("100.64.0.1")
|
||||
dstIP := net.ParseIP("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
@@ -179,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST in established",
|
||||
setupState: func() {
|
||||
// Establish connection first
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
},
|
||||
wantValid: true,
|
||||
desc: "Should accept RST for established connection",
|
||||
@@ -193,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
|
||||
name: "RST without connection",
|
||||
setupState: func() {},
|
||||
sendRST: func() {
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
},
|
||||
wantValid: false,
|
||||
desc: "Should reject RST without connection",
|
||||
@@ -206,455 +208,101 @@ func TestRSTHandling(t *testing.T) {
|
||||
tt.sendRST()
|
||||
|
||||
// Verify connection state is as expected
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
require.Equal(t, TCPStateClosed, conn.State)
|
||||
require.False(t, conn.IsEstablished())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPRetransmissions(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test SYN retransmission
|
||||
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||
// Initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Retransmit SYN (should not affect the state machine)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Verify we're still in SYN-SENT state
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// Complete the handshake
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Verify we're in ESTABLISHED state
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test ACK retransmission in established state
|
||||
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Retransmit ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test FIN retransmission
|
||||
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Retransmit FIN (should not change state)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPDataTransfer(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Data Transfer", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Receive ACK for data
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||
require.True(t, valid)
|
||||
|
||||
// Receive data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send ACK for received data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test half-closed connection: local end closes, remote end continues sending data
|
||||
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Remote end can still send data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||
require.True(t, valid)
|
||||
|
||||
// We can still ACK their data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Receive FIN from remote end
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
})
|
||||
|
||||
// Test half-closed connection: remote end closes, local end continues sending data
|
||||
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Receive FIN from remote
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// We can still send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Remote can still ACK our data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send our FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Receive final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPAbnormalSequences(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test handling of unsolicited RST in various states
|
||||
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||
// Send SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive unsolicited RST (without proper ACK)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||
|
||||
// Receive RST with proper ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
require.True(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Connection Timeout", func(t *testing.T) {
|
||||
// Establish a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Wait for the connection to timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
// Force cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after timeout")
|
||||
})
|
||||
|
||||
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Complete the connection close to enter TIME_WAIT
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||
// For the test, we're using a short timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSynFlood(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
basePort := uint16(10000)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Create a large number of SYN packets to simulate a SYN flood
|
||||
for i := uint16(0); i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Check that we're tracking all connections
|
||||
require.Equal(t, 1000, len(tracker.connections))
|
||||
|
||||
// Now simulate SYN timeout
|
||||
var oldConns int
|
||||
tracker.mutex.Lock()
|
||||
for _, conn := range tracker.connections {
|
||||
if conn.GetState() == TCPStateSynSent {
|
||||
// Make the connection appear old
|
||||
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||
oldConns++
|
||||
}
|
||||
}
|
||||
tracker.mutex.Unlock()
|
||||
require.Equal(t, 1000, oldConns)
|
||||
|
||||
// Run cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Check that stale connections were cleaned up
|
||||
require.Equal(t, 0, len(tracker.connections))
|
||||
}
|
||||
|
||||
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||
clientPort := uint16(12345)
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
DstIP: serverIP,
|
||||
SrcPort: clientPort,
|
||||
DstPort: serverPort,
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
conn := tracker.connections[key]
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||
|
||||
// 2. Server sends SYN-ACK response
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
// Server sends data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,8 +18,6 @@ const (
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
type UDPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
}
|
||||
|
||||
// UDPTracker manages UDP connection states
|
||||
@@ -32,126 +26,89 @@ type UDPTracker struct {
|
||||
connections map[ConnKey]*UDPConnTrack
|
||||
timeout time.Duration
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
flowLogger nftypes.FlowLogger
|
||||
done chan struct{}
|
||||
ipPool *PreallocatedIPs
|
||||
}
|
||||
|
||||
// NewUDPTracker creates a new UDP connection tracker
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
||||
if timeout == 0 {
|
||||
timeout = DefaultUDPTimeout
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &UDPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
flowLogger: flowLogger,
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
return key, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
FlowId: uuid.New(),
|
||||
Direction: direction,
|
||||
SourceIP: srcIP,
|
||||
DestIP: dstIP,
|
||||
},
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
conn, exists := t.connections[key]
|
||||
if !exists {
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, srcIP)
|
||||
copyIP(dstIPCopy, dstIP)
|
||||
|
||||
conn = &UDPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
},
|
||||
}
|
||||
conn.UpdateLastSeen()
|
||||
t.connections[key] = conn
|
||||
|
||||
t.logger.Trace("New UDP connection: %v", conn)
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
conn.UpdateLastSeen()
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
SrcPort: dstPort,
|
||||
DstPort: srcPort,
|
||||
}
|
||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
||||
conn.DestPort == srcPort &&
|
||||
conn.SourcePort == dstPort
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
|
||||
func (t *UDPTracker) cleanupRoutine() {
|
||||
for {
|
||||
select {
|
||||
case <-t.cleanupTicker.C:
|
||||
t.cleanup()
|
||||
case <-ctx.Done():
|
||||
case <-t.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -163,58 +120,44 @@ func (t *UDPTracker) cleanup() {
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *UDPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
t.cleanupTicker.Stop()
|
||||
close(t.done)
|
||||
|
||||
t.mutex.Lock()
|
||||
for _, conn := range t.connections {
|
||||
t.ipPool.Put(conn.SourceIP)
|
||||
t.ipPool.Put(conn.DestIP)
|
||||
}
|
||||
t.connections = nil
|
||||
t.mutex.Unlock()
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return conn, true
|
||||
}
|
||||
|
||||
// Timeout returns the configured timeout duration for the tracker
|
||||
func (t *UDPTracker) Timeout() time.Duration {
|
||||
return t.timeout
|
||||
}
|
||||
|
||||
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: conn.FlowId,
|
||||
Type: typ,
|
||||
RuleID: ruleID,
|
||||
Direction: conn.Direction,
|
||||
Protocol: nftypes.UDP,
|
||||
SourceIP: conn.SourceIP,
|
||||
DestIP: conn.DestIP,
|
||||
SourcePort: conn.SourcePort,
|
||||
DestPort: conn.DestPort,
|
||||
RxPackets: conn.PacketsRx.Load(),
|
||||
TxPackets: conn.PacketsTx.Load(),
|
||||
RxBytes: conn.BytesRx.Load(),
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -30,59 +29,54 @@ func TestNewUDPTracker(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||
tracker := NewUDPTracker(tt.timeout, logger)
|
||||
assert.NotNil(t, tracker)
|
||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||
assert.NotNil(t, tracker.connections)
|
||||
assert.NotNil(t, tracker.cleanupTicker)
|
||||
assert.NotNil(t, tracker.tickerCancel)
|
||||
assert.NotNil(t, tracker.done)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Verify connection was tracked
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
||||
conn, exists := tracker.connections[key]
|
||||
require.True(t, exists)
|
||||
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
||||
assert.Equal(t, srcPort, conn.SourcePort)
|
||||
assert.Equal(t, dstPort, conn.DestPort)
|
||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||
}
|
||||
|
||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||
tracker := NewUDPTracker(1*time.Second, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||
srcIP := net.ParseIP("192.168.1.2")
|
||||
dstIP := net.ParseIP("192.168.1.3")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(53)
|
||||
|
||||
// Track outbound connection
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
sleep time.Duration
|
||||
@@ -99,7 +93,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid source IP",
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: srcIP,
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
@@ -109,7 +103,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
{
|
||||
name: "invalid destination IP",
|
||||
srcIP: dstIP,
|
||||
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.4"),
|
||||
srcPort: dstPort,
|
||||
dstPort: srcPort,
|
||||
sleep: 0,
|
||||
@@ -149,7 +143,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||
if tt.sleep > 0 {
|
||||
time.Sleep(tt.sleep)
|
||||
}
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
@@ -160,45 +154,42 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
timeout := 50 * time.Millisecond
|
||||
cleanupInterval := 25 * time.Millisecond
|
||||
|
||||
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||
defer tickerCancel()
|
||||
|
||||
// Create tracker with custom cleanup interval
|
||||
tracker := &UDPTracker{
|
||||
connections: make(map[ConnKey]*UDPConnTrack),
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||
tickerCancel: tickerCancel,
|
||||
done: make(chan struct{}),
|
||||
ipPool: NewPreallocatedIPs(),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
go tracker.cleanupRoutine()
|
||||
|
||||
// Add some connections
|
||||
connections := []struct {
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcIP net.IP
|
||||
dstIP net.IP
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{
|
||||
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||
srcIP: net.ParseIP("192.168.1.2"),
|
||||
dstIP: net.ParseIP("192.168.1.3"),
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
},
|
||||
{
|
||||
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||
srcIP: net.ParseIP("192.168.1.4"),
|
||||
dstIP: net.ParseIP("192.168.1.5"),
|
||||
srcPort: 12346,
|
||||
dstPort: 53,
|
||||
},
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
@@ -220,33 +211,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
|
||||
func BenchmarkUDPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
srcIP := net.ParseIP("192.168.1.1")
|
||||
dstIP := net.ParseIP("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
@@ -81,10 +79,3 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
// src and remote is swapped
|
||||
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||
}
|
||||
|
||||
@@ -4,9 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -19,9 +17,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -32,20 +28,17 @@ const (
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
logger *nblog.Logger
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
ip net.IP
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@@ -71,11 +64,12 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
ones, _ := iface.Address().Network.Mask.Size()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
PrefixLen: iface.Address().Network.Bits(),
|
||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||
PrefixLen: ones,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -108,14 +102,13 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
udpForwarder: newUDPForwarder(mtu, logger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
ip: iface.Address().IP,
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -166,39 +159,8 @@ func (f *Forwarder) Stop() {
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
return addr.AsSlice()
|
||||
}
|
||||
|
||||
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||
key := buildKey(srcIP, dstIP, srcPort, dstPort)
|
||||
f.ruleIdMap.LoadOrStore(key, ruleID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return value.([]byte), true
|
||||
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||
return value.([]byte), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||
return
|
||||
}
|
||||
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
|
||||
}
|
||||
|
||||
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
|
||||
return conntrack.ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,30 +3,14 @@ package forwarder
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -34,55 +18,70 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
|
||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
// Get the complete ICMP message (header + data)
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
switch icmpHdr.Type() {
|
||||
case header.ICMPv4Echo:
|
||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
||||
case header.ICMPv4EchoReply:
|
||||
// dont process our own replies
|
||||
return true
|
||||
default:
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
||||
_, err = conn.WriteTo(payload, dst)
|
||||
if err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
||||
id, icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
|
||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||
}
|
||||
return 0
|
||||
return true
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
@@ -101,54 +100,10 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return 0
|
||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
|
||||
var rxPackets, txPackets uint64
|
||||
if rxBytes > 0 {
|
||||
rxPackets = 1
|
||||
}
|
||||
if txBytes > 0 {
|
||||
txPackets = 1
|
||||
}
|
||||
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.ICMP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -5,40 +5,24 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -60,105 +44,47 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
go f.proxyTCP(id, inConn, outConn, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
||||
defer func() {
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
// Create context for managing the proxy goroutines
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
_, err := io.Copy(outConn, inConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
_, err := io.Copy(inConn, outConn)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||
}
|
||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
||||
return
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.TCP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
@@ -18,7 +16,6 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,17 +28,15 @@ type udpPacketConn struct {
|
||||
lastSeen atomic.Int64
|
||||
cancel context.CancelFunc
|
||||
ep tcpip.Endpoint
|
||||
flowID uuid.UUID
|
||||
}
|
||||
|
||||
type udpForwarder struct {
|
||||
sync.RWMutex
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
logger *nblog.Logger
|
||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||
bufPool sync.Pool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type idleConn struct {
|
||||
@@ -49,14 +44,13 @@ type idleConn struct {
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: logger,
|
||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, mtu)
|
||||
@@ -78,10 +72,10 @@ func (f *udpForwarder) Stop() {
|
||||
for id, conn := range f.conns {
|
||||
conn.cancel()
|
||||
if err := conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
conn.ep.Close()
|
||||
@@ -112,10 +106,10 @@ func (f *udpForwarder) cleanup() {
|
||||
for _, idle := range idleConns {
|
||||
idle.conn.cancel()
|
||||
if err := idle.conn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
if err := idle.conn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
||||
}
|
||||
|
||||
idle.conn.ep.Close()
|
||||
@@ -124,7 +118,7 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -143,24 +137,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
||||
return
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
}
|
||||
}()
|
||||
|
||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||
if err != nil {
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
}
|
||||
@@ -171,7 +155,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if epErr != nil {
|
||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -184,7 +168,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
outConn: outConn,
|
||||
cancel: connCancel,
|
||||
ep: ep,
|
||||
flowID: flowID,
|
||||
}
|
||||
pConn.updateLastSeen()
|
||||
|
||||
@@ -194,114 +177,58 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.udpForwarder.Unlock()
|
||||
pConn.cancel()
|
||||
if err := inConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
success = true
|
||||
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
defer func() {
|
||||
pConn.cancel()
|
||||
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||
if err := pConn.conn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
||||
}
|
||||
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
if err := pConn.outConn.Close(); err != nil {
|
||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
var txBytes, rxBytes int64
|
||||
var outboundErr, inboundErr error
|
||||
|
||||
// outbound->inbound: copy from pConn.conn to pConn.outConn
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||
}()
|
||||
|
||||
// inbound->outbound: copy from pConn.outConn to pConn.conn
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
||||
}
|
||||
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||
// fields are flipped since this is the in conn
|
||||
rxPackets = udpStats.PacketsSent.Value()
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
// sendUDPEvent stores flow events for UDP connections
|
||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
fields := nftypes.EventFields{
|
||||
FlowID: flowID,
|
||||
Type: typ,
|
||||
Direction: nftypes.Ingress,
|
||||
Protocol: nftypes.UDP,
|
||||
// TODO: handle ipv6
|
||||
SourceIP: srcIp,
|
||||
DestIP: dstIp,
|
||||
SourcePort: id.RemotePort,
|
||||
DestPort: id.LocalPort,
|
||||
RxBytes: rxBytes,
|
||||
TxBytes: txBytes,
|
||||
RxPackets: rxPackets,
|
||||
TxPackets: txPackets,
|
||||
}
|
||||
|
||||
if typ == nftypes.TypeStart {
|
||||
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||
fields.RuleID = ruleId
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
||||
return
|
||||
case err := <-errChan:
|
||||
if err != nil && !isClosedError(err) {
|
||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||
}
|
||||
} else {
|
||||
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
||||
return
|
||||
}
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) updateLastSeen() {
|
||||
@@ -313,37 +240,38 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||
return time.Since(lastSeen)
|
||||
}
|
||||
|
||||
// copy reads from src and writes to dst.
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
|
||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||
bufp := bufPool.Get().(*[]byte)
|
||||
defer bufPool.Put(bufp)
|
||||
buffer := *bufp
|
||||
var totalBytes int64 = 0
|
||||
|
||||
if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return totalBytes, ctx.Err()
|
||||
}
|
||||
|
||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||
return totalBytes, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
|
||||
}
|
||||
|
||||
nWritten, err := dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to %s: %w", direction, err)
|
||||
}
|
||||
|
||||
totalBytes += int64(nWritten)
|
||||
c.updateLastSeen()
|
||||
c.updateLastSeen()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -14,13 +13,8 @@ import (
|
||||
type localIPManager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// fixed-size high array for upper byte of a IPv4 address
|
||||
ipv4Bitmap [256]*ipv4LowBitmap
|
||||
}
|
||||
|
||||
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||
type ipv4LowBitmap struct {
|
||||
bitmap [8192]uint32
|
||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||
ipv4Bitmap [1 << 16]uint32
|
||||
}
|
||||
|
||||
func newLocalIPManager() *localIPManager {
|
||||
@@ -32,61 +26,39 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||
if ipv4 == nil {
|
||||
return
|
||||
}
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
|
||||
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
if !ip.Is4() {
|
||||
return
|
||||
}
|
||||
ipv4 := ip.AsSlice()
|
||||
|
||||
high := uint16(ipv4[0])
|
||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||
|
||||
if bitmap[high] == nil {
|
||||
bitmap[high] = &ipv4LowBitmap{}
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
bitmap[high].bitmap[index] |= 1 << bit
|
||||
|
||||
if _, exists := ipv4Set[ip]; !exists {
|
||||
ipv4Set[ip] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||
high := uint16(ip[0])
|
||||
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||
|
||||
if m.ipv4Bitmap[high] == nil {
|
||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
index := low / 32
|
||||
bit := low % 32
|
||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||
}
|
||||
|
||||
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||
if int(high) >= len(*newIPv4Bitmap) {
|
||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||
}
|
||||
ipStr := ip.String()
|
||||
if _, exists := ipv4Set[ipStr]; !exists {
|
||||
ipv4Set[ipStr] = struct{}{}
|
||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||
@@ -104,13 +76,7 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
||||
continue
|
||||
}
|
||||
|
||||
addr, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||
log.Debugf("process IP failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -123,14 +89,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||
ipv4Set := make(map[netip.Addr]struct{})
|
||||
var ipv4Addresses []netip.Addr
|
||||
var newIPv4Bitmap [1 << 16]uint32
|
||||
ipv4Set := make(map[string]struct{})
|
||||
var ipv4Addresses []string
|
||||
|
||||
// 127.0.0.0/8
|
||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||
for i := 0; i < 8192; i++ {
|
||||
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||
high := uint16(127) << 8
|
||||
for i := uint16(0); i < 256; i++ {
|
||||
newIPv4Bitmap[high|i] = 0xffffffff
|
||||
}
|
||||
|
||||
if iface != nil {
|
||||
@@ -156,13 +122,13 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||
if !ip.Is4() {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.checkBitmapBit(ip.AsSlice())
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
return m.checkBitmapBit(ipv4)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,82 +2,90 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
func TestLocalIPManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupAddr wgaddr.Address
|
||||
testIP netip.Addr
|
||||
setupAddr iface.WGAddress
|
||||
testIP net.IP
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Localhost range",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||
testIP: net.ParseIP("127.0.0.2"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost standard address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||
testIP: net.ParseIP("127.0.0.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Localhost range edge",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||
testIP: net.ParseIP("127.255.255.255"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP matches",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||
testIP: net.ParseIP("192.168.1.1"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.0"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Local IP doesn't match - addresses 32 apart",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("192.168.1.1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||
testIP: net.ParseIP("192.168.1.2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
setupAddr: wgaddr.Address{
|
||||
IP: netip.MustParseAddr("fe80::1"),
|
||||
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
setupAddr: iface.WGAddress{
|
||||
IP: net.ParseIP("fe80::1"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("fe80::"),
|
||||
Mask: net.CIDRMask(64, 128),
|
||||
},
|
||||
},
|
||||
testIP: netip.MustParseAddr("fe80::1"),
|
||||
testIP: net.ParseIP("fe80::1"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
@@ -87,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
|
||||
manager := newLocalIPManager()
|
||||
|
||||
mock := &IFaceMock{
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return tt.setupAddr
|
||||
},
|
||||
}
|
||||
@@ -166,7 +174,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||
t.Logf("Testing %d IPs", len(tests))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
@@ -183,8 +191,10 @@ func BenchmarkIPChecks(b *testing.B) {
|
||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||
}
|
||||
|
||||
// Setup bitmap
|
||||
bitmapManager := newLocalIPManager()
|
||||
// Setup bitmap version
|
||||
bitmapManager := &localIPManager{
|
||||
ipv4Bitmap: [1 << 16]uint32{},
|
||||
}
|
||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||
bitmapManager.setBitmapBit(ip)
|
||||
}
|
||||
@@ -237,7 +247,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
|
||||
// Create two managers - one checks WG IP first, other checks it last
|
||||
b.Run("WG_First", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
bm.setBitmapBit(wgIP)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@@ -246,7 +256,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("WG_Last", func(b *testing.B) {
|
||||
bm := newLocalIPManager()
|
||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||
// Fill with other IPs first
|
||||
for i := 0; i < 15; i++ {
|
||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
||||
package log
|
||||
|
||||
import (
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxBatchSize = 1024 * 16
|
||||
maxMessageSize = 1024 * 2
|
||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
||||
maxMessageSize = 1024 * 2 // 2KB per message
|
||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
||||
defaultFlushInterval = 2 * time.Second
|
||||
logChannelSize = 1000
|
||||
)
|
||||
|
||||
// Level represents log severity
|
||||
type Level uint32
|
||||
|
||||
const (
|
||||
@@ -41,37 +42,32 @@ var levelStrings = map[Level]string{
|
||||
LevelTrace: "TRAC",
|
||||
}
|
||||
|
||||
type logMessage struct {
|
||||
level Level
|
||||
format string
|
||||
args []any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
type Logger struct {
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
msgChannel chan logMessage
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
bufPool sync.Pool
|
||||
output io.Writer
|
||||
level atomic.Uint32
|
||||
buffer *ringBuffer
|
||||
shutdown chan struct{}
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Reusable buffer pool for formatting messages
|
||||
bufPool sync.Pool
|
||||
}
|
||||
|
||||
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
l := &Logger{
|
||||
output: logrusLogger.Out,
|
||||
msgChannel: make(chan logMessage, logChannelSize),
|
||||
shutdown: make(chan struct{}),
|
||||
output: logrusLogger.Out,
|
||||
buffer: newRingBuffer(bufferSize),
|
||||
shutdown: make(chan struct{}),
|
||||
bufPool: sync.Pool{
|
||||
New: func() any {
|
||||
New: func() interface{} {
|
||||
// Pre-allocate buffer for message formatting
|
||||
b := make([]byte, 0, maxMessageSize)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logrusLevel := logrusLogger.GetLevel()
|
||||
l.level.Store(uint32(logrusLevel))
|
||||
level := levelStrings[Level(logrusLevel)]
|
||||
@@ -83,149 +79,97 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||
return l
|
||||
}
|
||||
|
||||
// SetLevel sets the logging level
|
||||
func (l *Logger) SetLevel(level Level) {
|
||||
l.level.Store(uint32(level))
|
||||
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
func (l *Logger) log(level Level, format string, args ...any) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||
default:
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
||||
*buf = (*buf)[:0]
|
||||
|
||||
// Timestamp
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Level
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Message
|
||||
if len(args) > 0 {
|
||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
||||
} else {
|
||||
*buf = append(*buf, format...)
|
||||
}
|
||||
|
||||
*buf = append(*buf, '\n')
|
||||
}
|
||||
|
||||
// Error logs a message at error level
|
||||
func (l *Logger) Error(format string, args ...any) {
|
||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
l.formatMessage(bufp, level, format, args...)
|
||||
|
||||
if len(*bufp) > maxMessageSize {
|
||||
*bufp = (*bufp)[:maxMessageSize]
|
||||
}
|
||||
_, _ = l.buffer.Write(*bufp)
|
||||
|
||||
l.bufPool.Put(bufp)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
l.log(LevelError, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn logs a message at warning level
|
||||
func (l *Logger) Warn(format string, args ...any) {
|
||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
l.log(LevelWarn, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs a message at info level
|
||||
func (l *Logger) Info(format string, args ...any) {
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelInfo) {
|
||||
l.log(LevelInfo, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logs a message at debug level
|
||||
func (l *Logger) Debug(format string, args ...any) {
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
l.log(LevelDebug, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// Trace logs a message at trace level
|
||||
func (l *Logger) Trace(format string, args ...any) {
|
||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
l.log(LevelTrace, format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||
*buf = (*buf)[:0]
|
||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||
*buf = append(*buf, ' ')
|
||||
*buf = append(*buf, levelStrings[level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
var msg string
|
||||
if len(args) > 0 {
|
||||
msg = fmt.Sprintf(format, args...)
|
||||
} else {
|
||||
msg = format
|
||||
}
|
||||
*buf = append(*buf, msg...)
|
||||
*buf = append(*buf, '\n')
|
||||
|
||||
if len(*buf) > maxMessageSize {
|
||||
*buf = (*buf)[:maxMessageSize]
|
||||
}
|
||||
}
|
||||
|
||||
// processMessage handles a single log message and adds it to the buffer
|
||||
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||
bufp := l.bufPool.Get().(*[]byte)
|
||||
defer l.bufPool.Put(bufp)
|
||||
|
||||
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||
|
||||
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||
_, _ = l.output.Write(*buffer)
|
||||
*buffer = (*buffer)[:0]
|
||||
}
|
||||
*buffer = append(*buffer, *bufp...)
|
||||
}
|
||||
|
||||
// flushBuffer writes the accumulated buffer to output
|
||||
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||
if len(*buffer) > 0 {
|
||||
_, _ = l.output.Write(*buffer)
|
||||
*buffer = (*buffer)[:0]
|
||||
}
|
||||
}
|
||||
|
||||
// processBatch processes as many messages as possible without blocking
|
||||
func (l *Logger) processBatch(buffer *[]byte) {
|
||||
for len(*buffer) < maxBatchSize {
|
||||
select {
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, buffer)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, buffer)
|
||||
case <-ctx.Done():
|
||||
l.flushBuffer(buffer)
|
||||
return
|
||||
}
|
||||
|
||||
if len(l.msgChannel) == 0 {
|
||||
l.flushBuffer(buffer)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// worker is the main goroutine that processes log messages
|
||||
// worker periodically flushes the buffer
|
||||
func (l *Logger) worker() {
|
||||
defer l.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(defaultFlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
buffer := make([]byte, 0, maxBatchSize)
|
||||
buf := make([]byte, 0, maxBatchSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.shutdown:
|
||||
l.handleShutdown(&buffer)
|
||||
return
|
||||
case <-ticker.C:
|
||||
l.flushBuffer(&buffer)
|
||||
case msg := <-l.msgChannel:
|
||||
l.processMessage(msg, &buffer)
|
||||
l.processBatch(&buffer)
|
||||
// Read accumulated messages
|
||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Write batch
|
||||
_, _ = l.output.Write(buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
package log_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
type discard struct{}
|
||||
|
||||
func (d *discard) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func BenchmarkLogger(b *testing.B) {
|
||||
simpleMessage := "Connection established"
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4 // TCPStateEstablished
|
||||
|
||||
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||
protocol := "TCP"
|
||||
direction := "outbound"
|
||||
flags := uint16(0x18) // ACK + PSH
|
||||
sequence := uint32(123456789)
|
||||
acknowledged := uint32(987654321)
|
||||
payloadSize := 1460
|
||||
fragmented := false
|
||||
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||
|
||||
b.Run("SimpleMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(simpleMessage)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ComplexMessage", func(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||
func BenchmarkLoggerParallel(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||
func BenchmarkLoggerBurst(b *testing.B) {
|
||||
logger := createTestLogger()
|
||||
defer cleanupLogger(logger)
|
||||
|
||||
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||
srcIP := "192.168.1.1"
|
||||
srcPort := uint16(12345)
|
||||
dstIP := "10.0.0.1"
|
||||
dstPort := uint16(443)
|
||||
state := 4
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < 100; j++ {
|
||||
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTestLogger() *log.Logger {
|
||||
logrusLogger := logrus.New()
|
||||
logrusLogger.SetOutput(&discard{})
|
||||
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||
return log.NewFromLogrus(logrusLogger)
|
||||
}
|
||||
|
||||
func cleanupLogger(logger *log.Logger) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = logger.Stop(ctx)
|
||||
}
|
||||
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
85
client/firewall/uspfilter/log/ringbuffer.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package log
|
||||
|
||||
import "sync"
|
||||
|
||||
// ringBuffer is a simple ring buffer implementation
|
||||
type ringBuffer struct {
|
||||
buf []byte
|
||||
size int
|
||||
r, w int64 // Read and write positions
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newRingBuffer(size int) *ringBuffer {
|
||||
return &ringBuffer{
|
||||
buf: make([]byte, size),
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(p) > r.size {
|
||||
p = p[:r.size]
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
|
||||
// Write data, handling wrap-around
|
||||
pos := int(r.w % int64(r.size))
|
||||
writeLen := min(len(p), r.size-pos)
|
||||
copy(r.buf[pos:], p[:writeLen])
|
||||
|
||||
// If we have more data and need to wrap around
|
||||
if writeLen < len(p) {
|
||||
copy(r.buf, p[writeLen:])
|
||||
}
|
||||
|
||||
// Update write position
|
||||
r.w += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.w == r.r {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Calculate available data accounting for wraparound
|
||||
available := int(r.w - r.r)
|
||||
if available < 0 {
|
||||
available += r.size
|
||||
}
|
||||
available = min(available, r.size)
|
||||
|
||||
// Limit read to buffer size
|
||||
toRead := min(available, len(p))
|
||||
if toRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Read data, handling wrap-around
|
||||
pos := int(r.r % int64(r.size))
|
||||
readLen := min(toRead, r.size-pos)
|
||||
n = copy(p, r.buf[pos:pos+readLen])
|
||||
|
||||
// If we need more data and need to wrap around
|
||||
if readLen < toRead {
|
||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
||||
}
|
||||
|
||||
// Update read position
|
||||
r.r += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -11,36 +12,34 @@ import (
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
comment string
|
||||
|
||||
udpHook func([]byte) bool
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *PeerRule) ID() string {
|
||||
// GetRuleID returns the rule id
|
||||
func (r *PeerRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
destinations []netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
id string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *RouteRule) ID() string {
|
||||
// GetRuleID returns the rule id
|
||||
func (r *RouteRule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
||||
}
|
||||
|
||||
type PacketTrace struct {
|
||||
SourceIP netip.Addr
|
||||
DestinationIP netip.Addr
|
||||
SourceIP net.IP
|
||||
DestinationIP net.IP
|
||||
Protocol string
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
||||
}
|
||||
|
||||
type PacketBuilder struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcIP net.IP
|
||||
DstIP net.IP
|
||||
Protocol fw.Protocol
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||
SrcIP: p.SrcIP.AsSlice(),
|
||||
DstIP: p.DstIP.AsSlice(),
|
||||
SrcIP: p.SrcIP,
|
||||
DstIP: p.DstIP,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,30 +260,28 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(dstIP) {
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if !m.handleRouting(trace) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.nativeRouter.Load() {
|
||||
if m.nativeRouter {
|
||||
return m.handleNativeRouter(trace)
|
||||
}
|
||||
|
||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||
}
|
||||
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
||||
msg := "No existing connection found"
|
||||
if allowed {
|
||||
msg = m.buildConntrackStateMessage(d)
|
||||
@@ -311,46 +309,32 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||
return msg
|
||||
}
|
||||
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
||||
if !m.localForwarding {
|
||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||
|
||||
strRuleId := "<no id>"
|
||||
if ruleId != nil {
|
||||
strRuleId = string(ruleId)
|
||||
}
|
||||
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||
msg := "Allowed by peer ACL rules"
|
||||
if blocked {
|
||||
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||
trace.AddResult(StagePeerACL, msg, false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||
return true
|
||||
msg = "Blocked by peer ACL rules"
|
||||
}
|
||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
||||
|
||||
trace.AddResult(StagePeerACL, msg, true)
|
||||
|
||||
// Handle netstack mode
|
||||
if m.netstack {
|
||||
switch {
|
||||
case !m.localForwarding:
|
||||
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||
case m.forwarder.Load() != nil:
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
default:
|
||||
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||
}
|
||||
return true
|
||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
||||
}
|
||||
|
||||
// In normal mode, packets are allowed through for local delivery
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||
if !m.routingEnabled.Load() {
|
||||
if !m.routingEnabled {
|
||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||
return false
|
||||
@@ -366,23 +350,18 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||
proto, _ := getProtocolFromPacket(d)
|
||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
||||
proto := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||
|
||||
strId := string(id)
|
||||
if id == nil {
|
||||
strId = "<no id>"
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||
msg := "Allowed by route ACLs"
|
||||
if !allowed {
|
||||
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||
msg = "Blocked by route ACLs"
|
||||
}
|
||||
trace.AddResult(StageRouteACL, msg, allowed)
|
||||
|
||||
if allowed && m.forwarder.Load() != nil {
|
||||
if allowed && m.forwarder != nil {
|
||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||
}
|
||||
|
||||
@@ -401,7 +380,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
dropped := m.processOutgoingHooks(packetData, 0)
|
||||
dropped := m.processOutgoingHooks(packetData)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
} else {
|
||||
|
||||
@@ -1,437 +0,0 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||
t.Logf("Trace results: %v", trace.Results)
|
||||
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||
for _, result := range trace.Results {
|
||||
actualStages = append(actualStages, result.Stage)
|
||||
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||
}
|
||||
|
||||
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||
lastResult := trace.Results[len(trace.Results)-1]
|
||||
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||
}
|
||||
|
||||
func TestTracePacket(t *testing.T) {
|
||||
setupTracerTest := func(statefulMode bool) *Manager {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
m.stateful = false
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||
builder := &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: protocol,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
Direction: direction,
|
||||
}
|
||||
|
||||
if protocol == "tcp" {
|
||||
builder.TCPState = &TCPState{SYN: true}
|
||||
}
|
||||
|
||||
return builder
|
||||
}
|
||||
|
||||
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||
return &PacketBuilder{
|
||||
SrcIP: netip.MustParseAddr(srcIP),
|
||||
DstIP: netip.MustParseAddr(dstIP),
|
||||
Protocol: "icmp",
|
||||
ICMPType: icmpType,
|
||||
ICMPCode: icmpCode,
|
||||
Direction: direction,
|
||||
}
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setup func(*Manager)
|
||||
packetBuilder func() *PacketBuilder
|
||||
expectedStages []PacketStage
|
||||
expectedAllow bool
|
||||
}{
|
||||
{
|
||||
name: "LocalTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = true
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "LocalTraffic_WithoutForwarder",
|
||||
setup: func(m *Manager) {
|
||||
m.netstack = true
|
||||
m.localForwarding = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLAllowed",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
m.forwarder.Store(&forwarder.Forwarder{})
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_ACLDenied",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(false)
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_NativeRouter",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(true)
|
||||
m.nativeRouter.Store(true)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
StageForwarding,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "RoutedTraffic_RoutingDisabled",
|
||||
setup: func(m *Manager) {
|
||||
m.routingEnabled.Store(false)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "ConnectionTracking_Hit",
|
||||
setup: func(m *Manager) {
|
||||
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||
return pb
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "OutboundTraffic",
|
||||
setup: func(m *Manager) {
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPEchoRequest",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "ICMPDestinationUnreachable",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithoutHook",
|
||||
setup: func(m *Manager) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
},
|
||||
{
|
||||
name: "UDPTraffic_WithHook",
|
||||
setup: func(m *Manager) {
|
||||
hookFunc := func([]byte) bool {
|
||||
return true
|
||||
}
|
||||
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
{
|
||||
name: "StatefulDisabled_NoTracking",
|
||||
setup: func(m *Manager) {
|
||||
m.stateful = false
|
||||
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
m := setupTracerTest(true)
|
||||
|
||||
tc.setup(m)
|
||||
|
||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||
"100.10.0.100 should be recognized as a local IP")
|
||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
|
||||
"192.168.17.2 should not be recognized as a local IP")
|
||||
|
||||
pb := tc.packetBuilder()
|
||||
|
||||
trace, err := m.TracePacketFromBuilder(pb)
|
||||
require.NoError(t, err)
|
||||
|
||||
verifyTraceStages(t, trace, tc.expectedStages)
|
||||
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -93,7 +93,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.ActionAccept, "", "allow all")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -113,15 +114,10 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -132,15 +128,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.ActionDrop, "", "default drop")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -169,11 +158,16 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Apply scenario-specific setup
|
||||
sc.setupFunc(manager)
|
||||
|
||||
@@ -188,13 +182,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
|
||||
// For stateful scenarios, establish the connection
|
||||
if sc.stateful {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
// Measure inbound packet processing
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -209,18 +203,23 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
// Pre-populate connection table
|
||||
srcIPs := generateRandomIPs(count)
|
||||
dstIPs := generateRandomIPs(count)
|
||||
for i := 0; i < count; i++ {
|
||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
// Test packet
|
||||
@@ -228,11 +227,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
// First establish our test connection
|
||||
manager.processOutgoingHooks(testOut, 0)
|
||||
manager.processOutgoingHooks(testOut)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(testIn, 0)
|
||||
manager.dropFilter(testIn)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -252,23 +251,28 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
|
||||
srcIP := generateRandomIPs(1)[0]
|
||||
dstIP := generateRandomIPs(1)[0]
|
||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||
|
||||
if sc.established {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -289,6 +293,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -302,6 +310,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -316,6 +328,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -329,6 +345,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
}
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -342,6 +362,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -355,6 +379,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -369,6 +397,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolTCP,
|
||||
state: "post_handshake",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -383,6 +415,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "new",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -396,6 +432,10 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
proto: layers.IPProtocolUDP,
|
||||
state: "established",
|
||||
setupFunc: func(m *Manager) {
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.CIDRMask(0, 32),
|
||||
}
|
||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||
},
|
||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||
@@ -410,9 +450,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
// Setup scenario
|
||||
@@ -426,25 +466,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
// For stateful cases and established connections
|
||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||
manager.processOutgoingHooks(outbound, 0)
|
||||
manager.processOutgoingHooks(outbound)
|
||||
|
||||
// For TCP post-handshake, simulate full handshake
|
||||
if sc.state == "post_handshake" {
|
||||
// SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.processOutgoingHooks(syn)
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.dropFilter(synack)
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.dropFilter(inbound, 0)
|
||||
manager.dropFilter(inbound)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -537,15 +577,23 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -568,17 +616,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Initial SYN
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.processOutgoingHooks(syn)
|
||||
|
||||
// SYN-ACK
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
// ACK
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Prepare test packets simulating bidirectional traffic
|
||||
@@ -599,9 +647,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
// First outbound data
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
// Then inbound response - this is what we're actually measuring
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -620,15 +668,23 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -700,19 +756,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Connection establishment
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
// Data transfer
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
// Connection teardown
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -731,14 +787,22 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -760,15 +824,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
for i := 0; i < sc.connCount; i++ {
|
||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||
manager.processOutgoingHooks(syn, 0)
|
||||
manager.processOutgoingHooks(syn)
|
||||
|
||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||
manager.dropFilter(synack, 0)
|
||||
manager.dropFilter(synack)
|
||||
|
||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||
manager.processOutgoingHooks(ack, 0)
|
||||
manager.processOutgoingHooks(ack)
|
||||
}
|
||||
|
||||
// Pre-generate test packets
|
||||
@@ -790,8 +854,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
counter++
|
||||
|
||||
// Simulate bidirectional traffic
|
||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||
manager.dropFilter(inPackets[connIdx], 0)
|
||||
manager.processOutgoingHooks(outPackets[connIdx])
|
||||
manager.dropFilter(inPackets[connIdx])
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -811,13 +875,21 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
require.NoError(b, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.SetNetwork(&net.IPNet{
|
||||
IP: net.ParseIP("100.64.0.0"),
|
||||
Mask: net.CIDRMask(10, 32),
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -879,17 +951,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
p := patterns[connIdx]
|
||||
|
||||
// Full connection lifecycle
|
||||
manager.processOutgoingHooks(p.syn, 0)
|
||||
manager.dropFilter(p.synAck, 0)
|
||||
manager.processOutgoingHooks(p.ack, 0)
|
||||
manager.processOutgoingHooks(p.syn)
|
||||
manager.dropFilter(p.synAck)
|
||||
manager.processOutgoingHooks(p.ack)
|
||||
|
||||
manager.processOutgoingHooks(p.request, 0)
|
||||
manager.dropFilter(p.response, 0)
|
||||
manager.processOutgoingHooks(p.request)
|
||||
manager.dropFilter(p.response)
|
||||
|
||||
manager.processOutgoingHooks(p.finClient, 0)
|
||||
manager.dropFilter(p.ackServer, 0)
|
||||
manager.dropFilter(p.finServer, 0)
|
||||
manager.processOutgoingHooks(p.ackClient, 0)
|
||||
manager.processOutgoingHooks(p.finClient)
|
||||
manager.dropFilter(p.ackServer)
|
||||
manager.dropFilter(p.finServer)
|
||||
manager.processOutgoingHooks(p.ackClient)
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -961,8 +1033,14 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
_, err := manager.AddRouteFiltering(
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
nil,
|
||||
r.port,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -984,8 +1062,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, tc := range cases {
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,33 +12,38 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
func TestPeerACLFiltering(t *testing.T) {
|
||||
localIP := netip.MustParseAddr("100.10.0.100")
|
||||
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||
localIP := net.ParseIP("100.10.0.100")
|
||||
wgNet := &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
})
|
||||
|
||||
manager.wgNetwork = wgNet
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -183,321 +188,24 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow UDP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "TCP packet doesn't match UDP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "UDP packet doesn't match TCP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match TCP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match UDP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic within port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Block TCP traffic outside port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Edge Case - Port at Range Boundary",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8100,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "UDP Port Range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 5060,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple source ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
ruleAction: fw.ActionAccept,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
// New drop test cases
|
||||
{
|
||||
name: "Drop TCP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop UDP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolUDP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop ICMP traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolICMP,
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop all traffic from WG peer",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolALL,
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop traffic from multiple source ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Drop TCP traffic within port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 8080,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Accept TCP traffic outside drop port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Drop TCP traffic with source port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 32100,
|
||||
dstPort: 80,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Mixed rule - drop specific port but allow other ports",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "100.10.0.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
ruleIP: "100.10.0.1",
|
||||
ruleProto: fw.ProtocolTCP,
|
||||
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||
ruleAction: fw.ActionDrop,
|
||||
shouldBeBlocked: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||
})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
tc.ruleProto,
|
||||
tc.ruleSrcPort,
|
||||
tc.ruleDstPort,
|
||||
tc.ruleAction,
|
||||
"",
|
||||
tc.name,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
@@ -509,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isDropped := manager.DropIncoming(packet, 0)
|
||||
isDropped := manager.DropIncoming(packet)
|
||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||
})
|
||||
}
|
||||
@@ -575,13 +283,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix(network)
|
||||
localIP, wgNet, err := net.ParseCIDR(network)
|
||||
require.NoError(tb, err)
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: localIP,
|
||||
Network: wgNet,
|
||||
}
|
||||
},
|
||||
@@ -593,15 +302,14 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
require.True(tb, manager.routingEnabled.Load())
|
||||
require.False(tb, manager.nativeRouter.Load())
|
||||
require.True(tb, manager.routingEnabled)
|
||||
require.False(tb, manager.nativeRouter)
|
||||
|
||||
tb.Cleanup(func() {
|
||||
require.NoError(tb, manager.Close(nil))
|
||||
require.NoError(tb, manager.Reset(nil))
|
||||
})
|
||||
|
||||
return manager
|
||||
@@ -612,7 +320,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
|
||||
type rule struct {
|
||||
sources []netip.Prefix
|
||||
dest fw.Network
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -638,7 +346,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -654,7 +362,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -670,7 +378,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -686,7 +394,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 53,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{Values: []uint16{53}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -700,7 +408,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: fw.ProtocolICMP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
@@ -715,7 +423,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -731,7 +439,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -747,7 +455,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -763,7 +471,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -779,7 +487,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -798,7 +506,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -812,7 +520,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
@@ -827,13 +535,33 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple source networks with mismatched protocol",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
// Should not match TCP rule
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "Allow multiple destination ports",
|
||||
srcIP: "100.10.0.1",
|
||||
@@ -843,7 +571,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -859,7 +587,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -875,7 +603,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolALL,
|
||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
@@ -892,7 +620,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -911,7 +639,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 7999,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -930,7 +658,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -949,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -971,7 +699,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8100,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -990,7 +718,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 5060,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -1009,7 +737,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 8080,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolALL,
|
||||
dstPort: &fw.Port{
|
||||
IsRange: true,
|
||||
@@ -1028,7 +756,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -1044,7 +772,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
@@ -1062,160 +790,18 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("100.10.0.0/16"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "Drop empty destination set",
|
||||
srcIP: "172.16.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
},
|
||||
dest: fw.Network{Set: fw.Set{}},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "Accept TCP traffic outside drop port range",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 7999,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||
action: fw.ActionDrop,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Allow TCP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 443,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "Allow UDP traffic without port specification",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 53,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: true,
|
||||
},
|
||||
{
|
||||
name: "TCP packet doesn't match UDP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolTCP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "UDP packet doesn't match TCP filter with same port",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolUDP,
|
||||
srcPort: 12345,
|
||||
dstPort: 80,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match TCP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolTCP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP packet doesn't match UDP filter",
|
||||
srcIP: "100.10.0.1",
|
||||
dstIP: "192.168.1.100",
|
||||
proto: fw.ProtocolICMP,
|
||||
rule: rule{
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
proto: fw.ProtocolUDP,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
shouldPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.rule.action == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
tc.rule.dest,
|
||||
tc.rule.proto,
|
||||
@@ -1230,12 +816,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
})
|
||||
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||
srcIP := net.ParseIP(tc.srcIP)
|
||||
dstIP := net.ParseIP(tc.dstIP)
|
||||
|
||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||
// to the forwarder
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
require.Equal(t, tc.shouldPass, isAllowed)
|
||||
})
|
||||
}
|
||||
@@ -1248,7 +834,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name string
|
||||
rules []struct {
|
||||
sources []netip.Prefix
|
||||
dest fw.Network
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -1269,7 +855,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name: "Drop rules take precedence over accept",
|
||||
rules: []struct {
|
||||
sources []netip.Prefix
|
||||
dest fw.Network
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -1278,7 +864,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Accept rule added first
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
||||
action: fw.ActionAccept,
|
||||
@@ -1286,7 +872,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Drop rule added second but should be evaluated first
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -1324,7 +910,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
name: "Multiple drop rules take precedence",
|
||||
rules: []struct {
|
||||
sources []netip.Prefix
|
||||
dest fw.Network
|
||||
dest netip.Prefix
|
||||
proto fw.Protocol
|
||||
srcPort *fw.Port
|
||||
dstPort *fw.Port
|
||||
@@ -1333,14 +919,14 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Accept all
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: fw.ProtocolALL,
|
||||
action: fw.ActionAccept,
|
||||
},
|
||||
{
|
||||
// Drop specific port
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{443}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -1348,7 +934,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
{
|
||||
// Drop different port
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
proto: fw.ProtocolTCP,
|
||||
dstPort: &fw.Port{Values: []uint16{80}},
|
||||
action: fw.ActionDrop,
|
||||
@@ -1398,7 +984,6 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
var rules []fw.Rule
|
||||
for _, r := range tc.rules {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
r.sources,
|
||||
r.dest,
|
||||
r.proto,
|
||||
@@ -1418,59 +1003,12 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
})
|
||||
|
||||
for i, p := range tc.packets {
|
||||
srcIP := netip.MustParseAddr(p.srcIP)
|
||||
dstIP := netip.MustParseAddr(p.dstIP)
|
||||
srcIP := net.ParseIP(p.srcIP)
|
||||
dstIP := net.ParseIP(p.dstIP)
|
||||
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteACLSet(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
// Add rule that uses the set (initially empty)
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||
|
||||
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||
|
||||
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now the packet should be allowed
|
||||
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package uspfilter
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -17,18 +16,15 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
AddressFunc func() iface.WGAddress
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
@@ -54,9 +50,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||
return i.SetFilterFunc(iface)
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Address() wgaddr.Address {
|
||||
func (i *IFaceMock) Address() iface.WGAddress {
|
||||
if i.AddressFunc == nil {
|
||||
return wgaddr.Address{}
|
||||
return iface.WGAddress{}
|
||||
}
|
||||
return i.AddressFunc()
|
||||
}
|
||||
@@ -66,7 +62,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -86,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -96,8 +92,9 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -119,25 +116,26 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := netip.MustParseAddr("192.168.1.1")
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -151,7 +149,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
@@ -162,7 +160,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
ip netip.Addr
|
||||
ip net.IP
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
expectedID string
|
||||
@@ -171,7 +169,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
ip: netip.MustParseAddr("10.168.0.1"),
|
||||
ip: net.IPv4(10, 168, 0, 1),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
},
|
||||
@@ -179,7 +177,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
ip: netip.MustParseAddr("::1"),
|
||||
ip: net.IPv6loopback,
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
},
|
||||
@@ -189,18 +187,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule PeerRule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.incomingRules[tt.ip] {
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
@@ -208,12 +206,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||
return
|
||||
}
|
||||
@@ -238,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -248,14 +246,15 @@ func TestManagerReset(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Close(nil)
|
||||
err = m.Reset(nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
@@ -269,25 +268,33 @@ func TestManagerReset(t *testing.T) {
|
||||
func TestNotMatchByIP(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
AddressFunc: func() iface.WGAddress {
|
||||
return iface.WGAddress{
|
||||
IP: net.ParseIP("100.10.0.100"),
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -321,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.dropFilter(buf.Bytes(), 0) {
|
||||
if m.dropFilter(buf.Bytes()) {
|
||||
t.Errorf("expected packet to be accepted")
|
||||
return
|
||||
}
|
||||
|
||||
if err = m.Close(nil); err != nil {
|
||||
if err = m.Reset(nil); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -340,17 +347,17 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false, flowLogger)
|
||||
manager, err := Create(iface, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
|
||||
// Add a UDP packet hook
|
||||
hookFunc := func(data []byte) bool { return true }
|
||||
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
@@ -386,13 +393,17 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
manager.udpTracker.Close()
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
|
||||
manager.decoders = sync.Pool{
|
||||
@@ -412,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
hookCalled := false
|
||||
hookID := manager.AddUDPPacketHook(
|
||||
false,
|
||||
netip.MustParseAddr("100.10.0.100"),
|
||||
net.ParseIP("100.10.0.100"),
|
||||
53,
|
||||
func([]byte) bool {
|
||||
hookCalled = true
|
||||
@@ -447,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test hook gets called
|
||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
result := manager.processOutgoingHooks(buf.Bytes())
|
||||
require.True(t, result)
|
||||
require.True(t, hookCalled)
|
||||
|
||||
@@ -457,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||
result = manager.processOutgoingHooks(buf.Bytes())
|
||||
require.False(t, result)
|
||||
}
|
||||
|
||||
@@ -468,12 +479,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Close(nil); err != nil {
|
||||
if err := manager.Reset(nil); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
@@ -483,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -495,11 +506,16 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.wgNetwork = &net.IPNet{
|
||||
IP: net.ParseIP("100.10.0.0"),
|
||||
Mask: net.CIDRMask(16, 32),
|
||||
}
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
||||
manager.decoders = sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -514,12 +530,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
}()
|
||||
|
||||
// Set up packet parameters
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||
srcIP := net.ParseIP("100.10.0.1")
|
||||
dstIP := net.ParseIP("100.10.0.100")
|
||||
srcPort := uint16(51334)
|
||||
dstPort := uint16(53)
|
||||
|
||||
@@ -527,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
outboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP.AsSlice(),
|
||||
DstIP: dstIP.AsSlice(),
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
outboundUDP := &layers.UDP{
|
||||
@@ -553,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Process outbound packet and verify connection tracking
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||
|
||||
// Verify connection was tracked
|
||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||
|
||||
@@ -569,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
inboundIPv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||
SrcIP: dstIP, // Original destination is now source
|
||||
DstIP: srcIP, // Original source is now destination
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
inboundUDP := &layers.UDP{
|
||||
@@ -620,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
for _, cp := range checkPoints {
|
||||
time.Sleep(cp.sleep)
|
||||
|
||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||
|
||||
// If the connection should still be valid, verify it exists
|
||||
@@ -669,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create a new outbound connection for invalid tests
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||
|
||||
for _, tc := range invalidCases {
|
||||
@@ -691,208 +707,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invalid packet is dropped
|
||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||
drop = manager.dropFilter(testBuf.Bytes())
|
||||
require.True(t, drop, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSetMerge(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
initialPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
// Update the set with initial prefixes
|
||||
err = manager.UpdateSet(set, initialPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test initial prefixes work
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
dstIP1 := netip.MustParseAddr("10.0.0.100")
|
||||
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
|
||||
|
||||
newPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, newPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all original prefixes are still included
|
||||
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||
|
||||
// Check that new prefixes are included
|
||||
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||
|
||||
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||
|
||||
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||
|
||||
// Verify the rule has all prefixes
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
foundRule = true
|
||||
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||
"Rule should have all prefixes merged")
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
require.True(t, foundRule, "Rule should be found")
|
||||
}
|
||||
|
||||
func TestUpdateSetDeduplication(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
|
||||
initialPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, initialPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the internal state for deduplication
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
foundRule = true
|
||||
// Should have deduplicated to 2 prefixes
|
||||
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||
|
||||
// Check the prefixes are correct
|
||||
expectedPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
for i, prefix := range expectedPrefixes {
|
||||
require.True(t, r.destinations[i] == prefix,
|
||||
"Prefix should match expected value")
|
||||
}
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
require.True(t, foundRule, "Rule should be found")
|
||||
|
||||
// Test with overlapping prefixes of different sizes
|
||||
overlappingPrefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/16"), // More general
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
|
||||
netip.MustParsePrefix("192.168.0.0/16"), // More general
|
||||
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
|
||||
}
|
||||
|
||||
err = manager.UpdateSet(set, overlappingPrefixes)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||
manager.mutex.RLock()
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||
require.Len(t, r.destinations, 4,
|
||||
"Overlapping prefixes should not be deduplicated")
|
||||
|
||||
// Verify they're sorted correctly (more specific prefixes should come first)
|
||||
prefixes := make([]string, 0, len(r.destinations))
|
||||
for _, p := range r.destinations {
|
||||
prefixes = append(prefixes, p.String())
|
||||
}
|
||||
|
||||
// Check sorted order
|
||||
require.Equal(t, []string{
|
||||
"10.0.0.0/16",
|
||||
"10.0.0.0/24",
|
||||
"192.168.0.0/16",
|
||||
"192.168.1.0/24",
|
||||
}, prefixes, "Prefixes should be sorted")
|
||||
}
|
||||
}
|
||||
manager.mutex.RUnlock()
|
||||
|
||||
// Test functionality with all prefixes
|
||||
testCases := []struct {
|
||||
dstIP netip.Addr
|
||||
expected bool
|
||||
desc string
|
||||
}{
|
||||
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
|
||||
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
|
||||
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
|
||||
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
|
||||
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
|
||||
}
|
||||
|
||||
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||
for _, tc := range testCases {
|
||||
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/stun/v2"
|
||||
@@ -13,8 +14,6 @@ import (
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -53,10 +52,9 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
@@ -66,7 +64,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
address: address,
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
@@ -111,17 +108,35 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
return s.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// force IPv4
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeIP] = conn
|
||||
b.endpoints[fakeAddr] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
|
||||
return fakeUDPAddr, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
||||
if !ok {
|
||||
log.Warnf("failed to convert IP to netip.Addr")
|
||||
return
|
||||
}
|
||||
|
||||
b.endpointsMu.Lock()
|
||||
defer b.endpointsMu.Unlock()
|
||||
|
||||
delete(b.endpoints, fakeIP)
|
||||
delete(b.endpoints, fakeAddr)
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
@@ -146,10 +161,9 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
@@ -261,6 +275,21 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
}
|
||||
}
|
||||
|
||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||
if len(octets) != 4 {
|
||||
return nil, fmt.Errorf("invalid IP format")
|
||||
}
|
||||
|
||||
newAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
||||
Port: peerAddress.Port,
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||
return msgsPool.Get().(*[]ipv6.Message)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -150,10 +149,49 @@ func isZeros(ip net.IP) bool {
|
||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
if params.Logger == nil {
|
||||
params.Logger = getLogger()
|
||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||
}
|
||||
|
||||
mux := &UDPMuxDefault{
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if params.Net == nil {
|
||||
var err error
|
||||
if params.Net, err = stdnet.NewNet(); err != nil {
|
||||
params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &UDPMuxDefault{
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
connsIPv4: make(map[string]*udpMuxedConn),
|
||||
@@ -165,55 +203,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
return newBufferHolder(receiveMTU + maxAddrSize)
|
||||
},
|
||||
},
|
||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
||||
}
|
||||
|
||||
mux.updateLocalAddresses()
|
||||
return mux
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||
var localAddrsForUnspecified []net.Addr
|
||||
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||
} else if ok && addr.IP.IsUnspecified() {
|
||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||
// it will break the applications that are already using unspecified UDP connection
|
||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||
var networks []ice.NetworkType
|
||||
switch {
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||
|
||||
default:
|
||||
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
|
||||
}
|
||||
if len(networks) > 0 {
|
||||
if m.params.Net == nil {
|
||||
var err error
|
||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
|
||||
if err == nil {
|
||||
for _, ip := range ips {
|
||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||
}
|
||||
} else {
|
||||
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.localAddrsForUnspecified = localAddrsForUnspecified
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||
@@ -223,12 +214,8 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
||||
|
||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
m.updateLocalAddresses()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if len(m.localAddrsForUnspecified) > 0 {
|
||||
return slices.Clone(m.localAddrsForUnspecified)
|
||||
return m.localAddrsForUnspecified
|
||||
}
|
||||
|
||||
return []net.Addr{m.LocalAddr()}
|
||||
@@ -238,10 +225,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||
// creates the connection if an existing one can't be found
|
||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||
// don't check addr for mux using unspecified address
|
||||
m.mu.Lock()
|
||||
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||
m.mu.Unlock()
|
||||
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||
return nil, fmt.Errorf("invalid address %s", addr.String())
|
||||
}
|
||||
|
||||
@@ -455,9 +439,3 @@ func newBufferHolder(size int) *bufferHolder {
|
||||
buf: make([]byte, size),
|
||||
}
|
||||
}
|
||||
|
||||
func getLogger() logging.LeveledLogger {
|
||||
fac := logging.NewDefaultLoggerFactory()
|
||||
//fac.Writer = log.StandardLogger().Writer()
|
||||
return fac.NewLogger("ice")
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user