Compare commits

..

3 Commits

Author SHA1 Message Date
Maycon Santos
1c5bb4357d try new file 2024-08-13 21:34:46 +02:00
Maycon Santos
2504a79c25 test fork 2024-08-13 21:24:38 +02:00
Maycon Santos
044ccfd05b test using ICO for linux 2024-08-13 20:17:31 +02:00
866 changed files with 27365 additions and 100447 deletions

View File

@@ -1,4 +1,4 @@
FROM golang:1.23-bullseye
FROM golang:1.21-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\

View File

@@ -7,7 +7,7 @@
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
"ghcr.io/devcontainers/features/go:1": {
"version": "1.23"
"version": "1.21"
}
},
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [netbirdio]

View File

@@ -31,22 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version`
**Is any other VPN software installed?**
**NetBird status -dA output:**
If yes, which one?
If applicable, add the `netbird status -dA' command output.
**Debug output**
**Do you face any client issues on desktop?**
To help us resolve the problem, please attach the following debug output
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots**
@@ -55,10 +47,3 @@ If applicable, add screenshots to help explain your problem.
**Additional context**
Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,4 +1,4 @@
name: "Darwin"
name: Test Code Darwin
on:
push:
@@ -12,24 +12,24 @@ concurrency:
jobs:
test:
name: "Client / Unit"
strategy:
matrix:
store: ['sqlite']
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
cache: false
go-version: "1.21.x"
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
key: macos-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
macos-gotest-
macos-go-
- name: Install libpcap
@@ -42,5 +42,4 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -1,4 +1,5 @@
name: "FreeBSD"
name: Test Code FreeBSD
on:
push:
@@ -12,7 +13,6 @@ concurrency:
jobs:
test:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
@@ -24,7 +24,7 @@ jobs:
copyback: false
release: "14.1"
prepare: |
pkg install -y go pkgconf xorg
pkg install -y go
# -x - to print all executed commands
# -e - to faile on first error
@@ -33,12 +33,12 @@ jobs:
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...

View File

@@ -1,4 +1,4 @@
name: Linux
name: Test Code Linux
on:
push:
@@ -11,483 +11,29 @@ concurrency:
cancel-in-progress: true
jobs:
build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
name: "Client / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management:
name: "Management / Unit"
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
cache: false
go-version: "1.21.x"
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache@v3
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -503,41 +49,26 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
go-version: "1.21.x"
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache@v3
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v3
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -548,6 +79,9 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -564,7 +98,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
- run: chmod +x *testing.bin
@@ -572,7 +106,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
@@ -590,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -1,4 +1,4 @@
name: "Windows"
name: Test Code Windows
on:
push:
@@ -14,34 +14,16 @@ concurrency:
jobs:
test:
name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
id: go
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
go-version: "1.21.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2
@@ -60,13 +42,11 @@ jobs:
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -1,4 +1,4 @@
name: Lint
name: golangci-lint
on: [pull_request]
permissions:
@@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
ignore_words_list: erro,clienta,hastable,
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -27,33 +27,26 @@ jobs:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
name: lint
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@v3
with:
version: latest
args: --timeout=12m --out-format colored-line-number
args: --timeout=12m

View File

@@ -13,7 +13,6 @@ concurrency:
jobs:
test-install-script:
strategy:
fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
@@ -22,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: run install script
env:

View File

@@ -1,4 +1,4 @@
name: Mobile
name: Mobile build validation
on:
push:
@@ -12,27 +12,26 @@ concurrency:
jobs:
android_build:
name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
uses: actions/setup-java@v3
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -48,15 +47,14 @@ jobs:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
@@ -64,4 +62,4 @@ jobs:
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0
CGO_ENABLED: 0

View File

@@ -3,16 +3,18 @@ name: Release
on:
push:
tags:
- "v*"
- 'v*'
branches:
- main
- validate-icon
pull_request:
env:
SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2"
SIGN_PIPE_VER: "v0.0.12"
GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,7 +22,7 @@ concurrency:
jobs:
release:
runs-on: ubuntu-22.04
runs-on: ubuntu-latest
env:
flags: ""
steps:
@@ -33,17 +35,20 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
-
name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
-
name: Set up Go
uses: actions/setup-go@v4
with:
go-version: "1.23"
go-version: "1.21"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
-
name: Cache Go modules
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
@@ -51,19 +56,24 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-releaser-
- name: Install modules
-
name: Install modules
run: go mod tidy
- name: check git status
-
name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
-
name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Login to Docker hub
-
name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
username: netbirdio
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -71,36 +81,40 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
args: release --rm-dist ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
- name: upload linux packages
-
name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
- name: upload windows packages
-
name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
- name: upload macos packages
-
name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
@@ -120,19 +134,19 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23"
go-version: "1.21"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: |
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
@@ -150,13 +164,13 @@ jobs:
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -174,17 +188,20 @@ jobs:
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
-
name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
-
name: Set up Go
uses: actions/setup-go@v4
with:
go-version: "1.23"
go-version: "1.21"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
-
name: Cache Go modules
uses: actions/cache@v3
with:
path: |
~/go/pkg/mod
@@ -192,35 +209,53 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin-
- name: Install modules
-
name: Install modules
run: go mod tidy
- name: check git status
-
name: check git status
run: git --no-pager diff --exit-code
- name: Run GoReleaser
-
name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
-
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
retention-days: 3
trigger_signer:
trigger_windows_signer:
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]
needs: [release,release_ui]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines
- name: Trigger Windows binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign bin and installer
workflow: Sign windows bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
inputs: '{ "tag": "${{ github.ref }}" }'
trigger_darwin_signer:
runs-on: ubuntu-latest
needs: [release,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger Darwin App binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign darwin ui app with dispatch
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -18,49 +18,7 @@ concurrency:
jobs:
test-docker-compose:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
mysql:
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
env:
MYSQL_USER: netbird
MYSQL_PASSWORD: mysql
MYSQL_ROOT_PASSWORD: mysqlroot
MYSQL_DATABASE: netbird
options: >-
--health-cmd "mysqladmin ping --silent"
--health-interval 10s
--health-timeout 5s
ports:
- 3306:3306
steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
if [ "${{ matrix.store }}" == "mysql" ]; then
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
fi
- name: Install jq
run: sudo apt-get install -y jq
@@ -68,12 +26,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@v4
with:
go-version: "1.23.x"
go-version: "1.21.x"
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@v3
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -81,7 +39,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -100,9 +58,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
@@ -129,9 +85,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -169,15 +123,6 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules
run: go mod tidy
@@ -203,15 +148,6 @@ jobs:
run: |
docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
@@ -223,15 +159,15 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs
test $count -eq 4
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
test-getting-started-script:
runs-on: ubuntu-latest
@@ -240,7 +176,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -255,7 +191,7 @@ jobs:
run: test -f management.json
- name: test turnserver.conf file gen postgres
run: |
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
@@ -266,9 +202,6 @@ jobs:
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test relay.env file gen postgres
run: test -f relay.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
@@ -293,7 +226,7 @@ jobs:
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
@@ -304,5 +237,20 @@ jobs:
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- name: test relay.env file gen CockroachDB
run: test -f relay.env
test-download-geolite2-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
- name: Checkout code
uses: actions/checkout@v3
- name: test script
run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists
run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
run: test -f geonames.db

2
.gitignore vendored
View File

@@ -29,4 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
vendor/
GeoLite2-City*

View File

@@ -103,7 +103,7 @@ linters:
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements
issues:
# Maximum count of issues with the same text.

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird
builds:
- id: netbird
@@ -24,7 +22,7 @@ builds:
goarch: 386
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
mod_timestamp: '{{ .CommitTimestamp }}'
tags:
- load_wgnt_from_rsrc
@@ -44,19 +42,19 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
mod_timestamp: '{{ .CommitTimestamp }}'
tags:
- load_wgnt_from_rsrc
- id: netbird-mgmt
dir: management
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-mgmt
goos:
- linux
@@ -66,7 +64,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-signal
dir: signal
@@ -80,24 +78,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-relay
dir: relay
env: [CGO_ENABLED=0]
binary: netbird-relay
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
mod_timestamp: '{{ .CommitTimestamp }}'
archives:
- builds:
@@ -105,6 +86,7 @@ archives:
- netbird-static
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
@@ -179,97 +161,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
ids:
@@ -422,30 +313,6 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
@@ -477,9 +344,10 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64
brews:
- ids:
-
ids:
- default
repository:
tap:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
@@ -496,7 +364,7 @@ brews:
uploads:
- name: debian
ids:
- netbird-deb
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -518,4 +386,4 @@ checksum:
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
- glob: ./release_files/install.sh

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -13,7 +11,7 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows
dir: client/ui
@@ -28,7 +26,7 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}"
mod_timestamp: '{{ .CommitTimestamp }}'
archives:
- id: linux-arch
@@ -41,6 +39,7 @@ archives:
- netbird-ui-windows
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
@@ -50,12 +49,10 @@ nfpms:
- netbird-ui
formats:
- deb
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/netbird.desktop
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png
- src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -69,12 +66,10 @@ nfpms:
- netbird-ui
formats:
- rpm
scripts:
postinstall: "release_files/ui-post-install.sh"
contents:
- src: client/ui/build/netbird.desktop
- src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png
- src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
@@ -82,7 +77,7 @@ nfpms:
uploads:
- name: debian
ids:
- netbird-ui-deb
- netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui
builds:
- id: netbird-ui-darwin
@@ -19,13 +17,10 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
mod_timestamp: '{{ .CommitTimestamp }}'
tags:
- load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives:
- builds:
- netbird-ui-darwin
@@ -33,4 +28,4 @@ archives:
checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog:
disable: true
skip: true

View File

@@ -1,3 +1,3 @@
Mikhail Bragin (https://github.com/braginini)
Maycon Santos (https://github.com/mlsmaycon)
NetBird GmbH
Wiretrustee UG (haftungsbeschränkt)

View File

@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser**
```shell
goreleaser build --snapshot --clean
goreleaser --snapshot --rm-dist
```
**golangci-lint**
```shell

View File

@@ -3,10 +3,10 @@
We are incredibly thankful for the contributions we receive from the community.
We require our external contributors to sign a Contributor License Agreement ("CLA") in
order to ensure that our projects remain licensed under Free and Open Source licenses such
as BSD-3 while allowing NetBird to build a sustainable business.
as BSD-3 while allowing Wiretrustee to build a sustainable business.
NetBird is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables NetBird to safely commercialize our products
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables Wiretrustee to safely commercialize our products
while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project.
@@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
in commercial products.
</li>
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
license to use that patent including within commercial products. You also agree that you
have permission to grant this license.
</li>
@@ -45,7 +45,7 @@ more.
# Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
@@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
require you to sign again.
# Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
You reserve all right, title, and interest in and to Your Contributions.
1. Definitions.
```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
@@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
```
```
"Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
or documentation of, any of the products owned or managed by NetBird (the "Work").
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of,
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution."
```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works.
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
@@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
with NetBird.
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
with Wiretrustee.
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
@@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.

View File

@@ -1,6 +1,6 @@
BSD 3-Clause License
Copyright (c) 2022 NetBird GmbH & AUTHORS
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
@@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,6 +1,11 @@
<div align="center">
<p align="center">
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
Learn more
</a>
</p>
<br/>
<br/>
<div align="center">
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>
</p>
@@ -12,12 +17,8 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>
@@ -29,14 +30,10 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
<br/>
</strong>
<br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
</a>
</p>
<br>
@@ -52,8 +49,6 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features
@@ -67,7 +62,6 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -1,4 +1,4 @@
FROM alpine:3.21.3
FROM alpine:3.19
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -1,17 +0,0 @@
FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird
WORKDIR /var/lib/netbird
USER netbird:netbird
ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]

View File

@@ -8,7 +8,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -16,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
)
@@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile
type TunAdapter interface {
device.TunAdapter
iface.TunAdapter
}
// IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter device.TunAdapter
tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc

View File

@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
return
})
if err != nil {

View File

@@ -12,8 +12,6 @@ import (
"strings"
)
const anonTLD = ".domain"
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
@@ -21,8 +19,6 @@ type Anonymizer struct {
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
domainKeyRegex *regexp.Regexp
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
@@ -38,8 +34,6 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
}
}
@@ -89,39 +83,29 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
baseDomain := domain
hasDot := strings.HasSuffix(domain, ".")
if hasDot {
baseDomain = domain[:len(domain)-1]
}
if strings.HasSuffix(baseDomain, "netbird.io") ||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
strings.HasSuffix(baseDomain, "netbird.stage") ||
strings.HasSuffix(baseDomain, anonTLD) {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(baseDomain, ".")
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseForLookup]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
a.domainAnonymizer[baseForLookup] = anonymizedBase
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
if hasDot {
result += "."
}
return result
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
@@ -168,22 +152,27 @@ func (a *Anonymizer) AnonymizeString(str string) string {
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.SplitN(match, "=", 2)
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, anonTLD) {
if strings.HasSuffix(domain, ".domain") {
return match
}
return "domain=" + a.AnonymizeDomain(domain)
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
@@ -212,8 +201,6 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
}
if slices.Contains(wellKnown, addr.String()) {

View File

@@ -46,59 +46,11 @@ func TestAnonymizeIP(t *testing.T) {
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
original string
expect string
}{
{
name: "Basic domain with trailing content",
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
original: "example.com",
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
},
{
name: "Domain with trailing dot",
input: "domain=example.com. processing request with status=pending",
original: "example.com",
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
},
{
name: "Multiple domains in log",
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
},
{
name: "Already anonymized domain",
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
original: "", // nothing should be anonymized
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
},
{
name: "Subdomain with trailing dot",
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
original: "example.com",
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
},
{
name: "Handler chain pattern log",
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
original: "example.com",
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
},
}
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDNSLogLine(tc.input)
if tc.original != "" {
assert.NotContains(t, result, tc.original)
}
assert.Regexp(t, tc.expect, result)
})
}
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
@@ -115,36 +67,18 @@ func TestAnonymizeDomain(t *testing.T) {
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Domain with Trailing Dot",
"example.com.",
`^anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain with Trailing Dot",
"sub.example.com.",
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
{
"Protected Domain with Trailing Dot",
"netbird.io.",
`^netbird\.io.$`,
false,
},
}
for _, tc := range tests {
@@ -206,16 +140,8 @@ func TestAnonymizeSchemeURI(t *testing.T) {
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
}
for _, tc := range tests {

View File

@@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -13,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
)
const errCloseConnection = "Failed to close connection: %v"
@@ -63,15 +61,6 @@ var forCmd = &cobra.Command{
RunE: runForDuration,
}
var persistenceCmd = &cobra.Command{
Use: "persistence [on|off]",
Short: "Set network map memory persistence",
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
Example: " netbird debug persistence on",
Args: cobra.ExactArgs(1),
RunE: setNetworkMapPersistence,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
@@ -86,7 +75,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag),
Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
@@ -182,13 +171,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(1 * time.Second)
// Enable network map persistence before bringing the service up
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
@@ -197,7 +179,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
@@ -207,7 +189,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
@@ -218,13 +200,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -244,43 +219,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return nil
}
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
persistence := strings.ToLower(args[0])
if persistence != "on" && persistence != "off" {
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
}
client := proto.NewDaemonServiceClient(conn)
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: persistence == "on",
})
if err != nil {
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
}
cmd.Printf("Network map persistence set to: %s\n", persistence)
return nil
}
func getStatusOutput(cmd *cobra.Command, anon bool) string {
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
)
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}

View File

@@ -42,8 +42,6 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err)
return err
}
cmd.Println("Disconnected")
return nil
},
}

View File

@@ -1,98 +0,0 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -85,17 +85,11 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {

View File

@@ -5,8 +5,8 @@ import (
"strings"
"testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)

View File

@@ -1,173 +0,0 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var networksCMD = &cobra.Command{
Use: "networks",
Aliases: []string{"routes"},
Short: "Manage networks",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List networks",
Example: " netbird networks list",
Long: "List all available network routes.",
RunE: networksList,
}
var routesSelectCmd = &cobra.Command{
Use: "select network...|all",
Short: "Select network",
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: networksSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect network...|all",
Short: "Deselect networks",
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: networksDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
}
func networksList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No networks available.")
return nil
}
printNetworks(cmd, resp)
return nil
}
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
cmd.Println("Available Networks:")
for _, route := range resp.Routes {
printNetwork(cmd, route)
}
}
func printNetwork(cmd *cobra.Command, route *proto.Network) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Network) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for resolvedDomain, ipList := range resolvedIPs {
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
}
}
func networksSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks selected successfully.")
return nil
}
func networksDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks deselected successfully.")
return nil
}

View File

@@ -1,33 +0,0 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -38,7 +38,6 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
)
var (
@@ -74,7 +73,6 @@ var (
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
blockLANAccess bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -144,23 +142,19 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+

View File

@@ -4,10 +4,6 @@ import (
"fmt"
"io"
"testing"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/iface"
)
func TestInitCommands(t *testing.T) {
@@ -38,44 +34,3 @@ func TestInitCommands(t *testing.T) {
})
}
}
func TestSetFlagsFromEnvVars(t *testing.T) {
var cmd = &cobra.Command{
Use: "netbird",
Long: "test",
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars(cmd)
},
}
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")
t.Setenv("NB_ENABLE_ROSENPASS", "true")
t.Setenv("NB_WIREGUARD_PORT", "10000")
err := cmd.Execute()
if err != nil {
t.Fatalf("expected no error while running netbird command, got %v", err)
}
if len(natExternalIPs) != 2 {
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
}
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
}
if interfaceName != "test-name" {
t.Errorf("expected test-name, got %s", interfaceName)
}
if !rosenpassEnabled {
t.Errorf("expected rosenpassEnabled to be true, got false")
}
if wireguardPort != 10000 {
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
}
}

174
client/cmd/route.go Normal file
View File

@@ -0,0 +1,174 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -2,23 +2,18 @@ package cmd
import (
"context"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/server"
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,10 +61,6 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
@@ -74,16 +70,6 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
if err != nil {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel()
if p.serv != nil {

View File

@@ -9,6 +9,7 @@ import (
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
@@ -72,7 +73,7 @@ var sshCmd = &cobra.Command{
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
log.Debug(err)
os.Exit(1)
}
cancel()

View File

@@ -1,181 +0,0 @@
package cmd
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var (
allFlag bool
)
var stateCmd = &cobra.Command{
Use: "state",
Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
}
var stateListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all stored states",
Long: "Lists all registered states with their status and basic information.",
Example: " netbird state list",
RunE: stateList,
}
var stateCleanCmd = &cobra.Command{
Use: "clean [state-name]",
Short: "Clean stored states",
Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
Example: ` netbird state clean dns_state
netbird state clean --all`,
RunE: stateClean,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
var stateDeleteCmd = &cobra.Command{
Use: "delete [state-name]",
Short: "Delete stored states",
Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
Example: ` netbird state delete dns_state
netbird state delete --all`,
RunE: stateDelete,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
func init() {
rootCmd.AddCommand(stateCmd)
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}
func stateList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
if err != nil {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
}
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
}
if resp.CleanedStates == 0 {
cmd.Println("No states were cleaned")
return nil
}
if allFlag {
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
} else {
cmd.Printf("Successfully cleaned state %q\n", stateName)
}
return nil
}
func stateDelete(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
}
if resp.DeletedStates == 0 {
cmd.Println("No states were deleted")
return nil
}
if allFlag {
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
} else {
cmd.Printf("Successfully deleted state %q\n", stateName)
}
return nil
}

View File

@@ -2,20 +2,105 @@ package cmd
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"sort"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
type peerStateDetailOutput struct {
FQDN string `json:"fqdn" yaml:"fqdn"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
}
type peersStateOutput struct {
Total int `json:"total" yaml:"total"`
Connected int `json:"connected" yaml:"connected"`
Details []peerStateDetailOutput `json:"details" yaml:"details"`
}
type signalStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type managementStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutput struct {
Total int `json:"total" yaml:"total"`
Available int `json:"available" yaml:"available"`
Details []relayStateOutputDetail `json:"details" yaml:"details"`
}
type iceCandidateType struct {
Local string `json:"local" yaml:"local"`
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
var (
detailFlag bool
ipv4Flag bool
@@ -86,17 +171,18 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
outputInformationHolder := convertToStatusOutputOverview(resp)
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = parseToJSON(outputInformationHolder)
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = parseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
}
if err != nil {
@@ -126,6 +212,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
func parseFilters() error {
switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" {
@@ -162,6 +249,170 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState()
managementOverview := managementStateOutput{
URL: managementState.GetURL(),
Connected: managementState.GetConnected(),
Error: managementState.Error,
}
signalState := pbFullStatus.GetSignalState()
signalOverview := signalStateOutput{
URL: signalState.GetURL(),
Connected: signalState.GetConnected(),
Error: signalState.Error,
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
overview := statusOutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
func mapRelays(relays []*proto.RelayState) relayStateOutput {
var relayStateDetail []relayStateOutputDetail
var relaysAvailable int
for _, relay := range relays {
available := relay.GetAvailable()
relayStateDetail = append(relayStateDetail,
relayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
},
)
if available {
relaysAvailable++
}
}
return relayStateOutput{
Total: len(relays),
Available: relaysAvailable,
Details: relayStateDetail,
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
connType := ""
peersConnected := 0
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
for _, pbPeerState := range peers {
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
}
if isPeerConnected {
peersConnected++
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
}
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
peerState := peerStateDetailOutput{
IP: pbPeerState.GetIP(),
PubKey: pbPeerState.GetPubKey(),
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
},
IceCandidateEndpoint: iceCandidateType{
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetRoutes(),
}
peersStateDetail = append(peersStateDetail, peerState)
}
sortPeersByIP(peersStateDetail)
peersOverview := peersStateOutput{
Total: len(peersStateDetail),
Connected: peersConnected,
Details: peersStateDetail,
}
return peersOverview
}
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
if len(peersStateDetail) > 0 {
sort.SliceStable(peersStateDetail, func(i, j int) bool {
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
return iAddr.Compare(jAddr) == -1
})
}
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {
@@ -169,3 +420,431 @@ func parseInterfaceIP(interfaceIP string) string {
}
return fmt.Sprintf("%s\n", ip)
}
func parseToJSON(overview statusOutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func parseToYAML(overview statusOutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
routes := "-"
if len(overview.Routes) > 0 {
sort.Strings(overview.Routes)
routes = strings.Join(overview.Routes, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
routes,
peersCountString,
)
return summary
}
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
"%s\n"+
"%s",
parsedPeersString,
summary,
)
}
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""
)
for _, peerState := range peers.Details {
localICE := "-"
if peerState.IceCandidateType.Local != "" {
localICE = peerState.IceCandidateType.Local
}
remoteICE := "-"
if peerState.IceCandidateType.Remote != "" {
remoteICE = peerState.IceCandidateType.Remote
}
localICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Local != "" {
localICEEndpoint = peerState.IceCandidateEndpoint.Local
}
remoteICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "true"
} else {
if rosenpassPermissive {
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
} else {
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
}
}
} else {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
}
}
routes := "-"
if len(peerState.Routes) > 0 {
sort.Strings(peerState.Routes)
routes = strings.Join(peerState.Routes, ", ")
}
peerString := fmt.Sprintf(
"\n %s:\n"+
" NetBird IP: %s\n"+
" Public key: %s\n"+
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
" Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
peerState.Status,
peerState.ConnType,
peerState.Direct,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
routes,
peerState.Latency.String(),
)
peersString += peerString
}
return peersString
}
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := false
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
statusEval = true
}
}
if len(ipsFilter) > 0 {
_, ok := ipsFilterMap[peerState.IP]
if !ok {
ipEval = true
}
}
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true
break
}
}
}
return statusEval || ipEval || nameEval
}
func toIEC(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -1,11 +1,579 @@
package cmd
import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
func init() {
loc, err := time.LoadLocation("UTC")
if err != nil {
panic(err)
}
time.Local = loc
}
var resp = &proto.StatusResponse{
Status: "Connected",
FullStatus: &proto.FullStatus{
Peers: []*proto.PeerState{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
Fqdn: "peer-1.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
Direct: true,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
RemoteIceCandidateEndpoint: "",
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Routes: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
Fqdn: "peer-2.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
Direct: false,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: &proto.SignalState{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: []*proto.RelayState{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
LocalPeerState: &proto.LocalPeerState{
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Routes: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
}
var overview = statusOutputOverview{
Peers: peersStateOutput{
Total: 2,
Connected: 2,
Details: []peerStateDetailOutput{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
FQDN: "peer-1.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
Direct: true,
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
},
IceCandidateEndpoint: iceCandidateType{
Local: "",
Remote: "",
},
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
FQDN: "peer-2.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
Direct: false,
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
},
IceCandidateEndpoint: iceCandidateType{
Local: "10.0.0.1:10001",
Remote: "10.0.10.1:10002",
},
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
CliVersion: version.NetbirdVersion(),
DaemonVersion: "0.14.1",
ManagementState: managementStateOutput{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: signalStateOutput{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: relayStateOutput{
Total: 2,
Available: 1,
Details: []relayStateOutputDetail{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
},
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := convertToStatusOutputOverview(resp)
assert.Equal(t, overview, convertedResult)
}
func TestSortingOfPeers(t *testing.T) {
peers := []peerStateDetailOutput{
{
IP: "192.168.178.104",
},
{
IP: "192.168.178.102",
},
{
IP: "192.168.178.101",
},
{
IP: "192.168.178.105",
},
{
IP: "192.168.178.103",
},
}
sortPeersByIP(peers)
assert.Equal(t, peers[3].IP, "192.168.178.104")
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := parseToJSON(overview)
//@formatter:off
expectedJSONString := `
{
"peers": {
"total": 2,
"connected": 2,
"details": [
{
"fqdn": "peer-1.awesome-domain.com",
"netbirdIp": "192.168.178.101",
"publicKey": "Pubkey1",
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
"direct": true,
"iceCandidateType": {
"local": "",
"remote": ""
},
"iceCandidateEndpoint": {
"local": "",
"remote": ""
},
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
]
},
{
"fqdn": "peer-2.awesome-domain.com",
"netbirdIp": "192.168.178.102",
"publicKey": "Pubkey2",
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
"direct": false,
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
},
"iceCandidateEndpoint": {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null
}
]
},
"cliVersion": "development",
"daemonVersion": "0.14.1",
"management": {
"url": "my-awesome-management.com:443",
"connected": true,
"error": ""
},
"signal": {
"url": "my-awesome-signal.com:443",
"connected": true,
"error": ""
},
"relays": {
"total": 2,
"available": 1,
"details": [
{
"uri": "stun:my-awesome-stun.com:3478",
"available": true,
"error": ""
},
{
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
"available": false,
"error": "context: deadline exceeded"
}
]
},
"netbirdIp": "192.168.178.100/16",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
}`
// @formatter:on
var expectedJSON bytes.Buffer
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
assert.Equal(t, expectedJSON.String(), jsonString)
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := parseToYAML(overview)
expectedYAML :=
`peers:
total: 2
connected: 2
details:
- fqdn: peer-1.awesome-domain.com
netbirdIp: 192.168.178.101
publicKey: Pubkey1
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
direct: true
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
direct: false
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
cliVersion: development
daemonVersion: 0.14.1
management:
url: my-awesome-management.com:443
connected: true
error: ""
signal:
url: my-awesome-signal.com:443
connected: true
error: ""
relays:
total: 2
available: 1
details:
- uri: stun:my-awesome-stun.com:3478
available: true
error: ""
- uri: turns:my-awesome-turn.com:443?transport=tcp
available: false
error: 'context: deadline exceeded'
netbirdIp: 192.168.178.100/16
publicKey: Some-Pub-Key
usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
Public key: Pubkey1
Status: Connected
-- detail --
Connection type: P2P
Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
Public key: Pubkey2
Status: Connected
-- detail --
Connection type: Relayed
Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Peers count: 2/2 Connected
`
assert.Equal(t, expectedString, shortVersion)
}
func TestParsingOfIP(t *testing.T) {
InterfaceIP := "192.168.178.123/16"
@@ -13,3 +581,31 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP)
}
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -1,31 +0,0 @@
package cmd
// Flag constants for system configuration
const (
disableClientRoutesFlag = "disable-client-routes"
disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall"
)
var (
disableClientRoutes bool
disableServerRoutes bool
disableDNS bool
disableFirewall bool
)
func init() {
// Add system flags to upCmd
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
}

View File

@@ -3,6 +3,7 @@ package cmd
import (
"context"
"net"
"path/filepath"
"testing"
"time"
@@ -10,9 +11,6 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@@ -36,12 +34,18 @@ func startTestingServices(t *testing.T) string {
if err != nil {
t.Fatal(err)
}
testDir := t.TempDir()
config.Datadir = testDir
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
if err != nil {
t.Fatal(err)
}
_, signalLis := startSignal(t)
signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config, "../testdata/store.sql")
_, mgmLis := startManagement(t, config)
mgmAddr := mgmLis.Addr().String()
return mgmAddr
}
@@ -53,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
srv, err := sig.NewServer(context.Background(), otel.Meter(""))
srv, err := sig.NewServer(otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
@@ -66,7 +70,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
@@ -74,7 +78,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
t.Fatal(err)
}
@@ -90,13 +94,12 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,137 +0,0 @@
package cmd
import (
"fmt"
"math/rand"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var traceCmd = &cobra.Command{
Use: "trace <direction> <source-ip> <dest-ip>",
Short: "Trace a packet through the firewall",
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
}
func init() {
debugCmd.AddCommand(traceCmd)
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
traceCmd.Flags().Uint16("sport", 0, "Source port")
traceCmd.Flags().Uint16("dport", 0, "Destination port")
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
}
func tracePacket(cmd *cobra.Command, args []string) error {
direction := strings.ToLower(args[0])
if direction != "in" && direction != "out" {
return fmt.Errorf("invalid direction: use 'in' or 'out'")
}
protocol := cmd.Flag("protocol").Value.String()
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
}
sport, err := cmd.Flags().GetUint16("sport")
if err != nil {
return fmt.Errorf("invalid source port: %v", err)
}
dport, err := cmd.Flags().GetUint16("dport")
if err != nil {
return fmt.Errorf("invalid destination port: %v", err)
}
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
if protocol != "icmp" {
if sport == 0 {
sport = uint16(rand.Intn(16383) + 49152)
}
if dport == 0 {
dport = uint16(rand.Intn(16383) + 49152)
}
}
var tcpFlags *proto.TCPFlags
if protocol == "tcp" {
syn, _ := cmd.Flags().GetBool("syn")
ack, _ := cmd.Flags().GetBool("ack")
fin, _ := cmd.Flags().GetBool("fin")
rst, _ := cmd.Flags().GetBool("rst")
psh, _ := cmd.Flags().GetBool("psh")
urg, _ := cmd.Flags().GetBool("urg")
tcpFlags = &proto.TCPFlags{
Syn: syn,
Ack: ack,
Fin: fin,
Rst: rst,
Psh: psh,
Urg: urg,
}
}
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
SourceIp: args[1],
DestinationIp: args[2],
Protocol: protocol,
SourcePort: uint32(sport),
DestinationPort: uint32(dport),
Direction: direction,
TcpFlags: tcpFlags,
IcmpType: &icmpType,
IcmpCode: &icmpCode,
})
if err != nil {
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
}
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
return nil
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
} else {
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
}
}
disposition := map[bool]string{
true: "\033[32mALLOWED\033[0m", // Green
false: "\033[31mDENIED\033[0m", // Red
}[resp.FinalDisposition]
cmd.Printf("\nFinal disposition: %s\n", disposition)
}

View File

@@ -15,12 +15,11 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
@@ -30,16 +29,9 @@ const (
interfaceInputType
)
const (
dnsLabelsFlag = "extra-dns-labels"
)
var (
foregroundMode bool
dnsLabels []string
dnsLabelsValidated domain.List
upCmd = &cobra.Command{
foregroundMode bool
upCmd = &cobra.Command{
Use: "up",
Short: "install, login and start Netbird client",
RunE: upFunc,
@@ -56,15 +48,6 @@ func init() {
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -83,11 +66,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
return err
}
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
if err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
if hostName != "" {
@@ -119,7 +97,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
@@ -170,23 +147,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
@@ -208,11 +168,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run(nil)
connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
return connectClient.Run()
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -262,8 +219,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -306,23 +261,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
if cmd.Flag(disableClientRoutesFlag).Changed {
loginRequest.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
loginRequest.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
loginRequest.DisableDns = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
loginRequest.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
loginRequest.BlockLanAccess = &blockLANAccess
}
var loginErr error
var loginResp *proto.LoginResponse
@@ -454,24 +392,6 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil
}
func validateDnsLabels(labels []string) (domain.List, error) {
var (
domains domain.List
err error
)
if len(labels) == 0 {
return domains, nil
}
domains, err = domain.ValidateDomains(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
return domains, nil
}
func isValidAddrPort(input string) bool {
if input == "" {
return true

View File

@@ -1,24 +0,0 @@
package configs
import (
"os"
"path/filepath"
"runtime"
)
var StateDir string
func init() {
StateDir = os.Getenv("NB_STATE_DIR")
if StateDir != "" {
return
}
switch runtime.GOOS {
case "windows":
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
case "darwin", "linux":
StateDir = "/var/lib/netbird"
case "freebsd", "openbsd", "netbsd", "dragonfly":
StateDir = "/var/db/netbird"
}
}

View File

@@ -1,167 +0,0 @@
// Package embed provides a way to embed the NetBird client directly
// into Go programs without requiring a separate NetBird client installation.
package embed
// Basic Usage:
//
// client, err := embed.New(embed.Options{
// DeviceName: "my-service",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// })
// if err != nil {
// log.Fatal(err)
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// Complete HTTP Server Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "log"
// "net/http"
// "os"
// "os/signal"
// "syscall"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-server",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP server
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
// fmt.Fprintf(w, "Hello from netbird!")
// })
//
// // Listen on netbird network
// l, err := client.ListenTCP(":8080")
// if err != nil {
// log.Fatal(err)
// }
//
// server := &http.Server{Handler: mux}
// go func() {
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
// log.Printf("HTTP server error: %v", err)
// }
// }()
//
// log.Printf("HTTP server listening on netbird network port 8080")
//
// // Handle shutdown
// stop := make(chan os.Signal, 1)
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// <-stop
//
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := server.Shutdown(shutdownCtx); err != nil {
// log.Printf("HTTP shutdown error: %v", err)
// }
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// Complete HTTP Client Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "io"
// "log"
// "os"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-client",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP client that uses netbird network
// httpClient := client.NewHTTPClient()
// httpClient.Timeout = 10 * time.Second
//
// // Make request to server in netbird network
// target := os.Getenv("NB_TARGET")
// resp, err := httpClient.Get(target)
// if err != nil {
// log.Fatal(err)
// }
// defer resp.Body.Close()
//
// // Read and print response
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Printf("Response from server: %s\n", string(body))
//
// // Clean shutdown
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// The package provides several methods for network operations:
// - Dial: Creates outbound connections
// - ListenTCP: Creates TCP listeners
// - ListenUDP: Creates UDP listeners
//
// By default, the embed package uses userspace networking mode, which doesn't
// require root/admin privileges. For production deployments, consider setting
// appropriate config and state paths for persistence.

View File

@@ -1,293 +0,0 @@
package embed
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
connect *internal.ConnectClient
}
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
// LogLevel sets the logging level (defaults to info if empty)
LogLevel string
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
NoUserspace bool
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
ConfigPath string
// StatePath is the path to the netbird state file
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
}
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
if opts.LogLevel != "" {
level, err := logrus.ParseLevel(opts.LogLevel)
if err != nil {
return nil, fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
}
if !opts.NoUserspace {
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
if opts.StatePath != "" {
// TODO: Disable state if path not provided
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
t := true
var config *internal.Config
var err error
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
config: config,
}, nil
}
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
}
}()
select {
case <-startCtx.Done():
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client
return nil
}
// Stop gracefully stops the client.
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
func (c *Client) Stop(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect == nil {
return ErrClientNotStarted
}
done := make(chan error, 1)
go func() {
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenUDP(udpAddr)
}
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) NewHTTPClient() *http.Client {
transport := &http.Transport{
DialContext: c.Dial,
}
return &http.Client{
Transport: transport,
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
}
addr, err := engine.Address()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
}
return nsnet, addr, nil
}

View File

@@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
if len(es) == 0 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))

View File

@@ -3,6 +3,7 @@
package firewall
import (
"context"
"fmt"
"runtime"
@@ -10,18 +11,16 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}

View File

@@ -3,7 +3,7 @@
package firewall
import (
"errors"
"context"
"fmt"
"os"
@@ -15,8 +15,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@@ -34,65 +32,54 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
var fm firewall.Manager
var errFw error
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface)
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface)
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
}
if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
if iface.IsUserspaceBind() {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
if errFw != nil {
return nil, errFw
}
return fm, nil
}

View File

@@ -1,18 +1,11 @@
package firewall
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
import "github.com/netbirdio/netbird/iface"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
SetFilter(iface.PacketFilter) error
}

View File

@@ -3,7 +3,7 @@ package iptables
import (
"fmt"
"net"
"slices"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
@@ -11,91 +11,83 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routeingFwChainName string
stateManager *statemanager.Manager
entries map[string][][]string
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
iptablesClient: iptablesClient,
wgIface: wgIface,
routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
}
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
id []byte,
func (m *aclManager) AddFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
chain := chainNameInputRules
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
var chain string
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
@@ -127,7 +119,7 @@ func (m *aclManager) AddPeerFiltering(
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
@@ -135,36 +127,40 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
}
m.updateState()
if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil
}
return []firewall.Rule{rule}, nil
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
// DeleteRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
@@ -189,34 +185,86 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
}
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
DELETERULE:
var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
m.updateState()
return nil
return err
}
func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
return m.cleanChains()
}
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
m.updateState()
specs = append(src, specs...)
return nil
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
@@ -245,7 +293,8 @@ func (m *aclManager) cleanChains() error {
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
@@ -254,6 +303,11 @@ func (m *aclManager) cleanChains() error {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
@@ -276,109 +330,109 @@ func (m *aclManager) createDefaultChains() error {
return err
}
for chainName, rules := range m.entries {
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
for chainName, rules := range m.entries {
for _, rule := range rules {
if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil
}
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD",
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
m.appendToEntries("PREROUTING",
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
switch direction {
case firewall.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
}
func actionToStr(action firewall.Action) string {
@@ -388,17 +442,32 @@ func actionToStr(action firewall.Action) string {
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
func transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != nil && dPort != nil:
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != nil:
case sPort != "":
return ipsetName + "-sport"
case dPort != nil:
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -4,17 +4,13 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/iface"
)
// Manager of iptables firewall
@@ -25,21 +21,21 @@ type Manager struct {
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
router *routerManager
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("init iptables: %w", err)
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
m := &Manager{
@@ -47,147 +43,80 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface)
m.router, err = newRouterManager(context, iptablesClient)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
}
return m, nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering adds a rule to the firewall
// AddFiltering rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
func (m *Manager) AddFiltering(
ip net.IP,
proto firewall.Protocol,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
return m.aclMgr.DeleteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
return m.router.InsertRoutingRules(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
return m.router.RemoveRoutingRules(pair)
}
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
errAcl := m.aclMgr.Reset()
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
}
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
errMgr := m.router.Reset()
if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
}
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
return errAcl
}
// AllowNetbird allows netbird interface traffic
@@ -196,53 +125,31 @@ func (m *Manager) AllowNetbird() error {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionIN,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
return nil
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -1,6 +1,7 @@
package iptables
import (
"context"
"fmt"
"net"
"testing"
@@ -10,28 +11,13 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -41,7 +27,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -54,71 +40,12 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
// just check on the local interface
manager, err := Create(ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Close(nil)
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
}
})
}
func TestIptablesManagerIPSet(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -129,26 +56,143 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Close(nil)
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
Values: []int{8043: 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
}
})
}
func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second)
defer func() {
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
}
rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -156,9 +200,18 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -166,7 +219,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
})
t.Run("reset check", func(t *testing.T) {
err = manager.Close(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
})
}
@@ -184,8 +237,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -198,13 +251,12 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Close(nil)
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -215,8 +267,12 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,18 +3,16 @@
package iptables
import (
"fmt"
"net/netip"
"context"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/assert"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
)
func isIptablesSupported() bool {
@@ -30,47 +28,45 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
_ = manager.Reset()
}()
// Now 5 rules:
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
// 4. jump rule to PRE mangle chain
// 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Source: "100.100.100.1/32",
Destination: "100.100.100.0/24",
Masquerade: true,
}
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
err = manager.AddNatRule(pair)
require.NoError(t, err, "adding NAT rule should not return error")
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_AddNatRule(t *testing.T) {
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -80,71 +76,78 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
require.True(t, found, "marking rule should exist in the map")
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "marking rule should not be created")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the map")
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveNatRule(t *testing.T) {
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -153,226 +156,72 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
_ = manager.Reset()
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the manager map")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if !isIptablesSupported() {
t.Skip("iptables not supported on this system")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}

View File

@@ -5,13 +5,12 @@ type Rule struct {
ruleID string
ipsetName string
specs []string
mangleSpecs []string
ip string
chain string
specs []string
ip string
chain string
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
func (r *Rule) GetRuleID() string {
return r.ruleID
}

View File

@@ -1,16 +1,14 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) *ipList {
func newIpList(ip string) ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return &ipList{
return ipList{
ips: ips,
}
}
@@ -19,52 +17,27 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct {
ipsets map[string]*ipList
ipsets map[string]ipList // ipsetName -> ruleset
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]*ipList),
ipsets: make(map[string]ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName)
}
@@ -75,29 +48,3 @@ func (s *ipsetStore) ipsetNames() []string {
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -1,69 +0,0 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@@ -1,24 +1,15 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s"
InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
)
// Rule abstraction should be implemented by each firewall manager
@@ -26,8 +17,8 @@ const (
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// ID returns the rule id
ID() string
// GetRuleID returns the rule id
GetRuleID() string
}
// RuleDirection is the traffic direction which a rule is applied
@@ -55,155 +46,43 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
id []byte,
AddFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
direction RuleDirection,
action Action,
ipsetName string,
comment string,
) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair RouterPair) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Close closes the firewall manager
Close(stateManager *statemanager.Manager) error
// Reset firewall to the default state
Reset() error
// Flush the changes to firewall controller
Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
}
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
func GenKey(format string, input string) string {
return fmt.Sprintf(format, input)
}

View File

@@ -1,192 +0,0 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -1,27 +0,0 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,37 +1,36 @@
package manager
import (
"fmt"
"strconv"
)
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []uint16
}
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
Values []int
}
// String interface implementation
@@ -41,11 +40,7 @@ func (p *Port) String() string {
if ports != "" {
ports += ","
}
ports += strconv.Itoa(int(port))
ports += strconv.Itoa(port)
}
if p.IsRange {
ports = "range:" + ports
}
return ports
}

View File

@@ -1,19 +0,0 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -1,26 +1,18 @@
package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct {
ID route.ID
Source netip.Prefix
Destination netip.Prefix
ID string
Source string
Destination string
Masquerade bool
Inverse bool
}
func GetInversePair(pair RouterPair) RouterPair {
func GetInPair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
Inverse: true,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -5,34 +5,20 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird"
tableNameFilter = "filter"
chainNameInput = "INPUT"
// tableName is the name of the table that is used for filtering by the Netbird client
tableName = "netbird"
)
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
@@ -44,82 +30,43 @@ type Manager struct {
}
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface)
workTable, err := m.createWorkTable()
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
m.router, err = newRouter(context, workTable)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err
}
return m, nil
}
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering rule to the firewall
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
func (m *Manager) AddFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -129,60 +76,33 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
}
func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
return m.aclManager.DeleteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
return m.router.AddRoutingRules(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair)
return m.router.RemoveRoutingRules(pair)
}
// AllowNetbird allows netbird interface traffic
@@ -206,7 +126,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
@@ -237,99 +157,47 @@ func (m *Manager) AllowNetbird() error {
return nil
}
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
return fmt.Errorf("list of chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
}
}
func (m *Manager) cleanupNetbirdTables() error {
m.router.ResetForwardRules()
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
return nil
}
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
return m.rConn.Flush()
}
// Flush rule/chain/set operations from the buffer
@@ -343,22 +211,6 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
@@ -366,12 +218,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
@@ -399,7 +251,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
@@ -413,38 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
}
return nil
}
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: getEstablishedExprs(1),
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -1,43 +1,26 @@
package nftables
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/iface"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
AddressFunc func() iface.WGAddress
}
func (i *iFaceMock) Name() string {
@@ -47,7 +30,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -57,15 +40,28 @@ func (i *iFaceMock) Address() wgaddr.Address {
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
defer func() {
err = manager.Close(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
@@ -74,7 +70,16 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -83,35 +88,17 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules")
expectedExprs1 := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
require.Len(t, rules, 1, "expected 1 rules")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{
expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@@ -147,10 +134,10 @@ func TestNftablesManager(t *testing.T) {
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule {
err = manager.DeletePeerRule(r)
err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
@@ -159,10 +146,9 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
require.Len(t, rules, 0, "expected 0 rules after deletion")
err = manager.Close(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
}
@@ -171,8 +157,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -185,13 +171,12 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
defer func() {
if err := manager.Close(nil); err != nil {
if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -200,8 +185,12 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -214,112 +203,3 @@ func TestNFtablesCreatePerformance(t *testing.T) {
})
}
}
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")
for i := range got {
if _, isCounter := got[i].(*expr.Counter); isCounter {
_, wantIsCounter := want[i].(*expr.Counter)
require.True(t, wantIsCounter, "expected Counter at index %d", i)
continue
}
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
}
}

View File

@@ -0,0 +1,431 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -3,16 +3,12 @@
package nftables
import (
"encoding/binary"
"net/netip"
"os/exec"
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -28,631 +24,195 @@ const (
NFTABLES
)
func TestNftablesManager_AddNatRule(t *testing.T) {
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
defer manager.ResetForwardRules()
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
})
require.NoError(t, err, "shouldn't return error")
err = manager.AddRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
// Build expected expressions for connection tracking
conntrackExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
}
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveNatRule(t *testing.T) {
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
rtr := manager.router
nftablesTestingClient := &nftables.Conn{}
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
defer manager.ResetForwardRules()
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.True(t, found, "NAT rule should exist before removal")
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
// Verify internal rule content
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
// Check if the rule exists in nftables and verify its content
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err, "Failed to get rules from nftables")
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
}
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
require.NotNil(t, nftRule, "Rule not found in nftables")
t.Log("Actual nftables rule expressions:")
for i, expr := range nftRule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
})
}
}
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
func TestNftablesCreateIpSet(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager.ResetForwardRules()
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IP",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
},
{
name: "Multiple IPs",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("172.16.0.1/32"),
},
},
{
name: "Single Subnet",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
{
name: "Multiple Subnets with Various Prefix Lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("203.0.113.0/26"),
},
},
{
name: "Mix of Single IPs and Subnets in Different Positions",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
netip.MustParsePrefix("203.0.113.0/24"),
},
},
{
name: "Overlapping IPs/Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}
// Add this helper function inside TestNftablesCreateIpSet
printNftSets := func() {
cmd := exec.Command("nft", "list", "sets")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to run 'nft list sets': %v", err)
} else {
t.Logf("Current nft sets:\n%s", output)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()
require.NoError(t, err, "Failed to create IP set")
}
require.NotNil(t, set, "Created set is nil")
// Verify set properties
assert.Equal(t, setName, set.Name, "Set name mismatch")
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
assert.True(t, set.Interval, "Set interval property should be true")
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
// Fetch the created set from nftables
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
require.NotNil(t, fetchedSet, "Fetched set is nil")
// Verify set elements
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
// Count the number of unique prefixes (excluding interval end markers)
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
// Check against expected merged prefixes
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
// Verify each expected prefix is in the set
for _, expected := range tt.expected {
found := false
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
if expected.Contains(ip) {
found = true
break
}
}
}
assert.True(t, found, "Expected prefix %s not found in set", expected)
}
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
t.Logf("Failed to delete set: %v", err)
printNftSets()
}
require.NoError(t, err, "Failed to delete set")
})
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
// Verify sources and destination
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
// Verify ports
if sPort != nil {
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
}
if dPort != nil {
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
}
// Verify action
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
}
func containsSetLookup(exprs []expr.Any) bool {
for _, e := range exprs {
if _, ok := e.(*expr.Lookup); ok {
return true
}
}
return false
}
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
var offset uint32
if isSource {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
var payloadFound, bitwiseFound, cmpFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
payloadFound = true
}
case *expr.Bitwise:
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
bitwiseFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
cmpFound = true
}
}
}
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
var payloadFound, portMatchFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Range:
if port.IsRange && len(port.Values) == 2 {
fromPort := binary.BigEndian.Uint16(ex.FromData)
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
portMatchFound = true
}
}
case *expr.Cmp:
if !port.IsRange {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if p == portValue {
portMatchFound = true
break
}
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}
}
if payloadFound && portMatchFound {
return true
}
})
}
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
if ex.Key == expr.MetaKeyL4PROTO {
metaFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
cmpFound = true
}
}
}
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
}
return false
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
@@ -690,12 +250,12 @@ func createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
@@ -713,7 +273,7 @@ func deleteWorkTable() {
}
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
sConn.DelTable(t)
}
}

View File

@@ -8,14 +8,13 @@ import (
// Rule to handle management of rules
type Rule struct {
nftRule *nftables.Rule
mangleRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
nftRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
func (r *Rule) GetRuleID() string {
return r.ruleID
}

View File

@@ -1,46 +0,0 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Close(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@@ -1,10 +1,8 @@
//go:build !android
package test
import (
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
var (
InsertRuleTestCases = []struct {
@@ -15,8 +13,8 @@ var (
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: false,
},
},
@@ -24,8 +22,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
@@ -40,8 +38,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.200.0/24"),
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},

View File

@@ -2,50 +2,16 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
return m.nativeFirewall.Reset()
}
return nil
}

View File

@@ -1,17 +1,11 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type action string
@@ -23,39 +17,12 @@ const (
)
// Reset firewall to the default state
func (m *Manager) Close(*statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if !isWindowsFirewallReachable() {
return nil

View File

@@ -1,16 +0,0 @@
package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -1,66 +0,0 @@
package conntrack
import (
"fmt"
"net/netip"
"sync/atomic"
"time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID
Direction nftypes.Direction
SourceIP netip.Addr
DestIP netip.Addr
lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
}
// these small methods will be inlined by the compiler
// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}
// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcPort uint16
DstPort uint16
}
func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
}

View File

@@ -1,68 +0,0 @@
package conntrack
import (
"context"
"net/netip"
"testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]netip.Addr, 100)
dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
}
}
})
}

View File

@@ -1,245 +0,0 @@
package conntrack
import (
"context"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
// DefaultICMPTimeout is the default timeout for ICMP connections
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
)
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
ID uint16
}
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
ICMPType uint8
ICMPCode uint8
}
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
flowLogger nftypes.FlowLogger
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
return tracker
}
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
return
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.tickerCancel()
t.mutex.Lock()
t.connections = nil
t.mutex.Unlock()
}
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
}

View File

@@ -1,39 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
)
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
}
})
}

View File

@@ -1,490 +0,0 @@
package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"context"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
// MSL (Maximum Segment Lifetime) is typically 2 minutes
MSL = 2 * time.Minute
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
TimeWaitTimeout = 2 * MSL
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPUrg uint8 = 0x20
)
const (
// DefaultTCPTimeout is the default timeout for established TCP connections
DefaultTCPTimeout = 3 * time.Hour
// TCPHandshakeTimeout is timeout for TCP handshake completion
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
)
// TCPState represents the state of a TCP connection
type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const (
TCPStateNew TCPState = iota
TCPStateSynSent
TCPStateSynReceived
TCPStateEstablished
TCPStateFinWait1
TCPStateFinWait2
TCPStateClosing
TCPStateTimeWait
TCPStateCloseWait
TCPStateLastAck
TCPStateClosed
)
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex
}
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
}
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
if conn.IsTombstone() {
return true
}
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
}
conn.Lock()
t.updateState(key, conn, flags, false)
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
return isEstablished || isValidState
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
state := conn.State
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
switch state {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
} else {
// Simultaneous open
conn.State = TCPStateSynReceived
}
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if isOutbound {
conn.State = TCPStateFinWait1
} else {
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
conn.State = TCPStateLastAck
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
}
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
switch state {
case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent:
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0
case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateFinWait2:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateClosing:
// In CLOSING state, we should accept the final ACK
return flags&TCPAck != 0
case TCPStateTimeWait:
// In TIME_WAIT, we might see retransmissions
return flags&TCPAck != 0
case TCPStateCloseWait:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
}
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
return
}
}
}
func (t *TCPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.IsEstablished():
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
// Return IPs to pool
delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
}
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()
// Clean up all remaining IPs
t.mutex.Lock()
t.connections = nil
t.mutex.Unlock()
}
func isValidFlagCombination(flags uint8) bool {
// Invalid: SYN+FIN
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
return false
}
// Invalid: RST with SYN or FIN
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
return false
}
return true
}
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,313 +0,0 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Security Tests", func(t *testing.T) {
tests := []struct {
name string
flags uint8
wantDrop bool
desc string
}{
{
name: "Block unsolicited SYN-ACK",
flags: TCPSyn | TCPAck,
wantDrop: true,
desc: "Should block SYN-ACK without prior SYN",
},
{
name: "Block invalid SYN-FIN",
flags: TCPSyn | TCPFin,
wantDrop: true,
desc: "Should block invalid SYN-FIN combination",
},
{
name: "Block unsolicited RST",
flags: TCPRst,
wantDrop: true,
desc: "Should block RST without connection",
},
{
name: "Block unsolicited ACK",
flags: TCPAck,
wantDrop: true,
desc: "Should block ACK without connection",
},
{
name: "Block data without connection",
flags: TCPAck | TCPPush,
wantDrop: true,
desc: "Should block data without established connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
})
t.Run("Connection Flow Tests", func(t *testing.T) {
tests := []struct {
name string
test func(*testing.T)
desc string
}{
{
name: "Normal Handshake",
test: func(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake")
},
},
{
name: "Normal Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
},
},
{
name: "RST During Connection",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{
name: "Simultaneous Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tt.test(t)
})
}
})
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
tests := []struct {
name string
setupState func()
sendRST func()
wantValid bool
desc string
}{
{
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
},
wantValid: true,
desc: "Should accept RST for established connection",
},
{
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
},
wantValid: false,
desc: "Should reject RST without connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupState()
tt.sendRST()
// Verify connection state is as expected
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
}
})
}
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -1,219 +0,0 @@
package conntrack
import (
"context"
"net/netip"
"sync"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
// DefaultUDPTimeout is the default timeout for UDP connections
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
)
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
SourcePort uint16
DestPort uint16
}
// UDPTracker manages UDP connection states
type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
flowLogger nftypes.FlowLogger
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-ctx.Done():
return
}
}
}
func (t *UDPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.tickerCancel()
t.mutex.Lock()
t.connections = nil
t.mutex.Unlock()
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,252 +0,0 @@
package conntrack
import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewUDPTracker(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantTimeout time.Duration
}{
{
name: "with custom timeout",
timeout: 1 * time.Minute,
wantTimeout: 1 * time.Minute,
},
{
name: "with zero timeout uses default",
timeout: 0,
wantTimeout: DefaultUDPTimeout,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.tickerCancel)
})
}
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct {
name string
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
sleep time.Duration
want bool
}{
{
name: "valid inbound response",
srcIP: dstIP, // Original destination is now source
dstIP: srcIP, // Original source is now destination
srcPort: dstPort, // Original destination port is now source
dstPort: srcPort, // Original source port is now destination
sleep: 0,
want: true,
},
{
name: "invalid source IP",
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid source port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: 54321,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: 54321,
sleep: 0,
want: false,
},
{
name: "expired connection",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 2 * time.Second, // Longer than tracker timeout
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got)
})
}
}
func TestUDPTracker_Cleanup(t *testing.T) {
// Use shorter intervals for testing
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
tickerCancel: tickerCancel,
logger: logger,
flowLogger: flowLogger,
}
// Start cleanup routine
go tracker.cleanupRoutine(ctx)
// Add some connections
connections := []struct {
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{
srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
}
// Verify initial connections
assert.Len(t, tracker.connections, 2)
// Wait for connection timeout and cleanup interval
time.Sleep(timeout + 2*cleanupInterval)
tracker.mutex.RLock()
connCount := len(tracker.connections)
tracker.mutex.RUnlock()
// Verify connections were cleaned up
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
// Properly close the tracker
tracker.Close()
}
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
}
})
}

View File

@@ -1,90 +0,0 @@
package forwarder
import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *endpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
continue
}
written++
}
return written, nil
}
func (e *endpoint) Wait() {
// not required
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -1,169 +0,0 @@
package forwarder
import (
"context"
"fmt"
"net"
"runtime"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
defaultReceiveWindow = 32768
defaultMaxInFlight = 1024
iosReceiveWindow = 16384
iosMaxInFlight = 256
)
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
},
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %s", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
})
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
maxInFlight := defaultMaxInFlight
if runtime.GOOS == "ios" {
receiveWindow = iosReceiveWindow
maxInFlight = iosMaxInFlight
}
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@@ -1,127 +0,0 @@
package forwarder
import (
"context"
"net"
"net/netip"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
}

View File

@@ -1,132 +0,0 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -1,332 +0,0 @@
package forwarder
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
udpTimeout = 30 * time.Second
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
flowID uuid.UUID
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
type idleConn struct {
id stack.TransportEndpointID
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, mtu)
return &b
},
},
}
go f.cleanup()
return f
}
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
conn.ep.Close()
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
var idleConns []idleConn
f.RLock()
for id, conn := range f.conns {
if conn.getIdleDuration() > udpTimeout {
idleConns = append(idleConns, idleConn{id, conn})
}
}
f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
}
idle.conn.ep.Close()
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
id := r.ID()
f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
return
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,
ep: ep,
flowID: flowID,
}
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}()
errChan := make(chan error, 2)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
for {
if ctx.Err() != nil {
return ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
c.updateLastSeen()
}
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

View File

@@ -1,131 +0,0 @@
package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
}
func newLocalIPManager() *localIPManager {
return &localIPManager{}
}
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
return
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
}
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
}
if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
}
}
interfaces, err := net.Interfaces()
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}
}
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
return nil
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false
}

View File

@@ -1,271 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr wgaddr.Address
testIP netip.Addr
expected bool
}{
{
name: "Localhost range",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return tt.setupAddr
},
}
err := manager.UpdateLocalIPs(mock)
require.NoError(t, err)
result := manager.IsLocalIP(tt.testIP)
require.Equal(t, tt.expected, result)
})
}
}
func TestLocalIPManager_AllInterfaces(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{}
// Get actual local interfaces
interfaces, err := net.Interfaces()
require.NoError(t, err)
var tests []struct {
ip string
expected bool
}
// Add all local interface IPs to test cases
for _, iface := range interfaces {
addrs, err := iface.Addrs()
require.NoError(t, err)
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip4 := ip.To4(); ip4 != nil {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip4.String(),
expected: true,
})
}
}
}
// Add some external IPs as negative test cases
externalIPs := []string{
"8.8.8.8",
"1.1.1.1",
"208.67.222.222",
}
for _, ip := range externalIPs {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip,
expected: false,
})
}
require.NotEmpty(t, tests, "No test cases generated")
err = manager.UpdateLocalIPs(mock)
require.NoError(t, err)
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -1,252 +0,0 @@
// Package log provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
)
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
log.Debugf("New uspfilter logger created with loglevel %v", level)
l.wg.Add(1)
go l.worker()
return l
}
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buffer := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
l.handleShutdown(&buffer)
return
case <-ticker.C:
l.flushBuffer(&buffer)
case msg := <-l.msgChannel:
l.processMessage(msg, &buffer)
l.processBatch(&buffer)
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
done := make(chan struct{})
l.closeOnce.Do(func() {
close(l.shutdown)
})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@@ -1,121 +0,0 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -1,45 +1,30 @@
package uspfilter
import (
"net/netip"
"net"
"github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// PeerRule to handle management of rules
type PeerRule struct {
// Rule to handle management of rules
type Rule struct {
id string
mgmtId []byte
ip netip.Addr
ip net.IP
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
sPort uint16
dPort uint16
drop bool
comment string
udpHook func([]byte) bool
}
// ID returns the rule id
func (r *PeerRule) ID() string {
return r.id
}
type RouteRule struct {
id string
mgmtId []byte
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// ID returns the rule id
func (r *RouteRule) ID() string {
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
return r.id
}

Some files were not shown because too many files have changed in this diff Show More