Compare commits

..

19 Commits

Author SHA1 Message Date
Zoltán Papp
c495eaa549 Move interface near to engine 2024-10-10 14:46:51 +02:00
Zoltán Papp
b8026ad541 Merge branch 'main' into relay/fix/wg-roaming 2024-10-09 23:24:07 +02:00
Zoltán Papp
a5deeda727 Revert force install change 2024-10-09 19:20:20 +02:00
Zoltán Papp
5b2d5f8df1 Try to force install libpcap 2024-10-09 19:12:33 +02:00
Zoltán Papp
6369706ade Merge branch 'main' into relay/fix/wg-roaming 2024-10-09 18:54:30 +02:00
Zoltan Papp
e3dfbe5acf Add trace log 2024-10-09 14:07:35 +02:00
Zoltan Papp
deeb05047d Handle addr resolve error 2024-10-09 14:05:43 +02:00
Zoltan Papp
1814b07a4b Replace error check to errors.Is 2024-10-09 14:02:23 +02:00
Zoltán Papp
b04d19bb0a Fix nil pointer in error handling 2024-10-08 12:04:57 +02:00
Zoltán Papp
20815c9f90 Remove unused function 2024-10-07 13:28:21 +02:00
Zoltán Papp
ba3cdb30ee Remove unnecessary ctx cancel check 2024-10-07 13:05:11 +02:00
Zoltán Papp
1f25bb0751 Reducate cognitive complexity 2024-10-07 12:58:45 +02:00
Zoltán Papp
9e7aac3a56 Reducate cognitive complexity 2024-10-07 12:52:55 +02:00
Zoltán Papp
718d9526a7 Fix test 2024-10-07 12:45:21 +02:00
Zoltán Papp
48184ecf21 Fix eBPF pause handling 2024-10-07 12:40:53 +02:00
Zoltán Papp
f18ae8b925 Apply pause logic 2024-10-07 11:22:48 +02:00
Zoltán Papp
90d9dd4c08 Remove unused function from eBPF proxy 2024-10-07 10:35:53 +02:00
Zoltán Papp
acad98e328 Code cleaning 2024-10-03 02:29:46 +02:00
Zoltán Papp
9d75cc3273 Add pause function for proxies 2024-10-03 01:24:05 +02:00
398 changed files with 11840 additions and 37294 deletions

3
.github/FUNDING.yml vendored
View File

@@ -1,3 +0,0 @@
# These are supported funding model platforms
github: [netbirdio]

View File

@@ -21,7 +21,6 @@ jobs:
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
@@ -29,9 +28,8 @@ jobs:
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
key: macos-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
macos-gotest-
macos-go-
- name: Install libpcap
@@ -44,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -11,135 +11,8 @@ concurrency:
cancel-in-progress: true
jobs:
build-cache:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
@@ -149,26 +22,19 @@ jobs:
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -183,85 +49,27 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -271,6 +79,9 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -287,7 +98,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
- run: chmod +x *testing.bin
@@ -295,7 +106,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -24,23 +24,6 @@ jobs:
id: go
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
@@ -59,13 +42,11 @@ jobs:
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd
ignore_words_list: erro,clienta,hastable,iif
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -46,7 +46,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
uses: golangci/golangci-lint-action@v3
with:
version: latest
args: --timeout=12m --out-format colored-line-number
args: --timeout=12m

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.17"
SIGN_PIPE_VER: "v0.0.14"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -223,4 +223,4 @@ jobs:
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -96,9 +96,6 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
archives:
- builds:
- netbird

View File

@@ -23,9 +23,6 @@ builds:
tags:
- load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives:
- builds:
- netbird-ui-darwin

View File

@@ -17,12 +17,8 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a>
</p>
</div>
@@ -34,7 +30,7 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
<br/>
</strong>

View File

@@ -1,4 +1,4 @@
FROM alpine:3.21.0
FROM alpine:3.20
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -12,8 +12,6 @@ import (
"strings"
)
const anonTLD = ".domain"
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
@@ -21,8 +19,6 @@ type Anonymizer struct {
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
domainKeyRegex *regexp.Regexp
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
@@ -38,8 +34,6 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
}
}
@@ -89,39 +83,29 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
baseDomain := domain
hasDot := strings.HasSuffix(domain, ".")
if hasDot {
baseDomain = domain[:len(domain)-1]
}
if strings.HasSuffix(baseDomain, "netbird.io") ||
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
strings.HasSuffix(baseDomain, "netbird.cloud") ||
strings.HasSuffix(baseDomain, "netbird.stage") ||
strings.HasSuffix(baseDomain, anonTLD) {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(baseDomain, ".")
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseForLookup]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
a.domainAnonymizer[baseForLookup] = anonymizedBase
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
if hasDot {
result += "."
}
return result
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
@@ -168,22 +152,27 @@ func (a *Anonymizer) AnonymizeString(str string) string {
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.SplitN(match, "=", 2)
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, anonTLD) {
if strings.HasSuffix(domain, ".domain") {
return match
}
return "domain=" + a.AnonymizeDomain(domain)
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
@@ -212,8 +201,6 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
}
if slices.Contains(wellKnown, addr.String()) {

View File

@@ -46,59 +46,11 @@ func TestAnonymizeIP(t *testing.T) {
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
original string
expect string
}{
{
name: "Basic domain with trailing content",
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
original: "example.com",
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
},
{
name: "Domain with trailing dot",
input: "domain=example.com. processing request with status=pending",
original: "example.com",
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
},
{
name: "Multiple domains in log",
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
},
{
name: "Already anonymized domain",
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
original: "", // nothing should be anonymized
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
},
{
name: "Subdomain with trailing dot",
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
original: "example.com",
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
},
{
name: "Handler chain pattern log",
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
original: "example.com",
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
},
}
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDNSLogLine(tc.input)
if tc.original != "" {
assert.NotContains(t, result, tc.original)
}
assert.Regexp(t, tc.expect, result)
})
}
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
@@ -115,36 +67,18 @@ func TestAnonymizeDomain(t *testing.T) {
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Domain with Trailing Dot",
"example.com.",
`^anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain with Trailing Dot",
"sub.example.com.",
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
{
"Protected Domain with Trailing Dot",
"netbird.io.",
`^netbird\.io.$`,
false,
},
}
for _, tc := range tests {
@@ -206,16 +140,8 @@ func TestAnonymizeSchemeURI(t *testing.T) {
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
}
for _, tc := range tests {

View File

@@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"
"strings"
"time"
log "github.com/sirupsen/logrus"
@@ -62,15 +61,6 @@ var forCmd = &cobra.Command{
RunE: runForDuration,
}
var persistenceCmd = &cobra.Command{
Use: "persistence [on|off]",
Short: "Set network map memory persistence",
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
Example: " netbird debug persistence on",
Args: cobra.ExactArgs(1),
RunE: setNetworkMapPersistence,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
@@ -181,13 +171,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(1 * time.Second)
// Enable network map persistence before bringing the service up
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
@@ -217,13 +200,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -243,34 +219,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return nil
}
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
persistence := strings.ToLower(args[0])
if persistence != "on" && persistence != "off" {
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
}
client := proto.NewDaemonServiceClient(conn)
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: persistence == "on",
})
if err != nil {
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
}
cmd.Printf("Network map persistence set to: %s\n", persistence)
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())

View File

@@ -1,173 +0,0 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var networksCMD = &cobra.Command{
Use: "networks",
Aliases: []string{"routes"},
Short: "Manage networks",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List networks",
Example: " netbird networks list",
Long: "List all available network routes.",
RunE: networksList,
}
var routesSelectCmd = &cobra.Command{
Use: "select network...|all",
Short: "Select network",
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: networksSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect network...|all",
Short: "Deselect networks",
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: networksDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
}
func networksList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No networks available.")
return nil
}
printNetworks(cmd, resp)
return nil
}
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
cmd.Println("Available Networks:")
for _, route := range resp.Routes {
printNetwork(cmd, route)
}
}
func printNetwork(cmd *cobra.Command, route *proto.Network) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Network) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for resolvedDomain, ipList := range resolvedIPs {
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
}
}
func networksSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks selected successfully.")
return nil
}
func networksDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks deselected successfully.")
return nil
}

View File

@@ -1,33 +0,0 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -142,20 +142,19 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+

174
client/cmd/route.go Normal file
View File

@@ -0,0 +1,174 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -2,7 +2,6 @@ package cmd
import (
"context"
"sync"
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
@@ -14,11 +13,10 @@ import (
)
type program struct {
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
serverInstanceMu sync.Mutex
ctx context.Context
cancel context.CancelFunc
serv *grpc.Server
serverInstance *server.Server
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,9 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
@@ -74,7 +72,6 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
@@ -82,7 +79,6 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.serverInstanceMu.Unlock()
p.cancel()

View File

@@ -1,181 +0,0 @@
package cmd
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var (
allFlag bool
)
var stateCmd = &cobra.Command{
Use: "state",
Short: "Manage daemon state",
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
}
var stateListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List all stored states",
Long: "Lists all registered states with their status and basic information.",
Example: " netbird state list",
RunE: stateList,
}
var stateCleanCmd = &cobra.Command{
Use: "clean [state-name]",
Short: "Clean stored states",
Long: `Clean specific state or all states. The daemon must not be running.
This will perform cleanup operations and remove the state.`,
Example: ` netbird state clean dns_state
netbird state clean --all`,
RunE: stateClean,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
var stateDeleteCmd = &cobra.Command{
Use: "delete [state-name]",
Short: "Delete stored states",
Long: `Delete specific state or all states from storage. The daemon must not be running.
This will remove the state without performing any cleanup operations.`,
Example: ` netbird state delete dns_state
netbird state delete --all`,
RunE: stateDelete,
PreRunE: func(cmd *cobra.Command, args []string) error {
// Check mutual exclusivity between --all flag and state-name argument
if allFlag && len(args) > 0 {
return fmt.Errorf("cannot specify both --all flag and state name")
}
if !allFlag && len(args) != 1 {
return fmt.Errorf("requires a state name argument or --all flag")
}
return nil
},
}
func init() {
rootCmd.AddCommand(stateCmd)
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
}
func stateList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
if err != nil {
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
}
cmd.Printf("\nStored states:\n\n")
for _, state := range resp.States {
cmd.Printf("- %s\n", state.Name)
}
return nil
}
func stateClean(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
}
if resp.CleanedStates == 0 {
cmd.Println("No states were cleaned")
return nil
}
if allFlag {
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
} else {
cmd.Printf("Successfully cleaned state %q\n", stateName)
}
return nil
}
func stateDelete(cmd *cobra.Command, args []string) error {
var stateName string
if !allFlag {
stateName = args[0]
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
StateName: stateName,
All: allFlag,
})
if err != nil {
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
}
if resp.DeletedStates == 0 {
cmd.Println("No states were deleted")
return nil
}
if allFlag {
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
} else {
cmd.Printf("Successfully deleted state %q\n", stateName)
}
return nil
}

View File

@@ -40,7 +40,6 @@ type peerStateDetailOutput struct {
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
}
type peersStateOutput struct {
@@ -99,7 +98,6 @@ type statusOutputOverview struct {
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
@@ -284,8 +282,7 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
@@ -393,8 +390,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetNetworks(),
Networks: pbPeerState.GetNetworks(),
Routes: pbPeerState.GetRoutes(),
}
peersStateDetail = append(peersStateDetail, peerState)
@@ -495,10 +491,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
routes := "-"
if len(overview.Routes) > 0 {
sort.Strings(overview.Routes)
routes = strings.Join(overview.Routes, ", ")
}
var dnsServersString string
@@ -560,7 +556,6 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
@@ -573,8 +568,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
networks,
networks,
routes,
peersCountString,
)
return summary
@@ -637,10 +631,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
}
}
networks := "-"
if len(peerState.Networks) > 0 {
sort.Strings(peerState.Networks)
networks = strings.Join(peerState.Networks, ", ")
routes := "-"
if len(peerState.Routes) > 0 {
sort.Strings(peerState.Routes)
routes = strings.Join(peerState.Routes, ", ")
}
peerString := fmt.Sprintf(
@@ -658,7 +652,6 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
@@ -675,8 +668,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
networks,
networks,
routes,
peerState.Latency.String(),
)
@@ -688,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := true
nameEval := false
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
@@ -708,13 +700,11 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval
@@ -818,14 +808,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
@@ -866,10 +848,6 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
}
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}

View File

@@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Networks: []string{
Routes: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
@@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
Routes: []string{
"10.10.0.0/24",
},
},
@@ -149,9 +149,6 @@ var overview = statusOutputOverview{
Routes: []string{
"10.1.0.0/24",
},
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
@@ -233,9 +230,6 @@ var overview = statusOutputOverview{
Routes: []string{
"10.10.0.0/24",
},
Networks: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -301,9 +295,6 @@ func TestParsingToJSON(t *testing.T) {
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
]
},
{
@@ -327,8 +318,7 @@ func TestParsingToJSON(t *testing.T) {
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null,
"networks": null
"routes": null
}
]
},
@@ -369,9 +359,6 @@ func TestParsingToJSON(t *testing.T) {
"routes": [
"10.10.0.0/24"
],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
@@ -431,8 +418,6 @@ func TestParsingToYAML(t *testing.T) {
quantumResistance: false
routes:
- 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
@@ -452,7 +437,6 @@ func TestParsingToYAML(t *testing.T) {
latency: 10ms
quantumResistance: false
routes: []
networks: []
cliVersion: development
daemonVersion: 0.14.1
management:
@@ -481,8 +465,6 @@ quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
@@ -527,7 +509,6 @@ func TestParsingToDetail(t *testing.T) {
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
@@ -544,7 +525,6 @@ func TestParsingToDetail(t *testing.T) {
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Networks: -
Latency: 10ms
OS: %s/%s
@@ -563,7 +543,6 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -585,7 +564,6 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -10,8 +10,6 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@@ -40,7 +38,7 @@ func startTestingServices(t *testing.T) string {
signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config, "../testdata/store.sql")
_, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
mgmAddr := mgmLis.Addr().String()
return mgmAddr
}
@@ -73,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -95,7 +93,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -3,6 +3,7 @@
package firewall
import (
"context"
"fmt"
"runtime"
@@ -10,11 +11,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -3,7 +3,7 @@
package firewall
import (
"errors"
"context"
"fmt"
"os"
@@ -15,7 +15,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@@ -33,65 +32,54 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager)
var fm firewall.Manager
var errFw error
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface)
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface)
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
if iface.IsUserspaceBind() {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
if errFw != nil {
return nil, errFw
}
return fm, nil
}

View File

@@ -11,8 +11,6 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -23,23 +21,13 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
stateManager *statemanager.Manager
entries map[string][][]string
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@@ -48,35 +36,27 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
wgIface: wgIface,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
}
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
}
return m, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
err = m.cleanChains()
if err != nil {
return nil, err
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
err = m.createDefaultChains()
if err != nil {
return nil, err
}
m.updateState()
return nil
return m, nil
}
func (m *aclManager) AddPeerFiltering(
@@ -157,8 +137,6 @@ func (m *aclManager) AddPeerFiltering(
chain: chain,
}
m.updateState()
return []firewall.Rule{rule}, nil
}
@@ -193,23 +171,15 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
}
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
m.updateState()
return nil
return err
}
func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
return m.cleanChains()
}
// todo write less destructive cleanup mechanism
@@ -262,19 +232,6 @@ func (m *aclManager) cleanChains() error {
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@@ -310,17 +267,6 @@ func (m *aclManager) createDefaultChains() error {
}
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil
}
@@ -332,63 +278,27 @@ func (m *aclManager) createDefaultChains() error {
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,

View File

@@ -8,13 +8,10 @@ import (
"sync"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager of iptables firewall
@@ -36,10 +33,10 @@ type iFaceMapper interface {
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("init iptables: %w", err)
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
m := &Manager{
@@ -47,51 +44,20 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface)
m.router, err = newRouter(context, iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
}
return m, nil
}
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
@@ -112,7 +78,7 @@ func (m *Manager) AddPeerFiltering(
}
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
sources [] netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -167,27 +133,20 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
}
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
var merr *multierror.Error
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
errAcl := m.aclMgr.Reset()
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
}
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
errMgr := m.router.Reset()
if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
}
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
return errAcl
}
// AllowNetbird allows netbird interface traffic
@@ -207,9 +166,19 @@ func (m *Manager) AllowNetbird() error {
"",
)
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
return nil
_, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
}
// Flush doesn't need to be implemented for this manager

View File

@@ -1,6 +1,7 @@
package iptables
import (
"context"
"fmt"
"net"
"testing"
@@ -55,14 +56,13 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err)
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Reset(nil)
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@@ -154,14 +154,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Reset(nil)
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)
@@ -220,7 +219,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
})
t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
})
}
@@ -252,13 +251,12 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)
defer func() {
err := manager.Reset(nil)
err := manager.Reset()
require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second)

View File

@@ -3,6 +3,7 @@
package iptables
import (
"context"
"fmt"
"net/netip"
"strconv"
@@ -17,25 +18,22 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
)
@@ -50,31 +48,28 @@ type routeFilteringRuleParams struct {
SetName string
}
type routeRules map[string][]string
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules routeRules
ipsetCounter *ipsetCounter
rules map[string][]string
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
wgIface iFaceMapper
legacyManagement bool
stateManager *statemanager.Manager
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
r.createIpSet,
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
@@ -84,23 +79,16 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
return nil, fmt.Errorf("init ipset: %w", err)
}
return r, nil
}
func (r *router) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("cleanup routing rules: %s", err)
return nil, err
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
err = r.createContainers()
if err != nil {
log.Errorf("create containers for route: %s", err)
}
r.updateState()
return nil
return r, err
}
func (r *router) AddRouteFiltering(
@@ -141,8 +129,6 @@ func (r *router) AddRouteFiltering(
r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil
}
@@ -166,8 +152,6 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
log.Debugf("route rule %s not found", ruleKey)
}
r.updateState()
return nil
}
@@ -180,18 +164,18 @@ func (r *router) findSetNameInRule(rule []string) string {
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
return struct{}{}, nil
}
func (r *router) deleteIpSet(setName string) error {
@@ -222,8 +206,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
@@ -241,8 +223,6 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
@@ -298,13 +278,8 @@ func (r *router) RemoveAllLegacyRouteRules() error {
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
@@ -319,31 +294,31 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, err)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanUpDefaultForwardRules() error {
if err := r.cleanJumpRules(); err != nil {
return fmt.Errorf("clean jump rules: %w", err)
err := r.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
for _, chain := range []string{chainRTFWD, chainRTNAT} {
table := tableFilter
if chain == chainRTNAT {
table = tableNat
}
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
} else if ok {
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
}
}
}
@@ -352,58 +327,17 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
func (r *router) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %v", chain, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
return fmt.Errorf("insert established rule: %v", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
return r.addJumpRules()
}
func (r *router) createAndSetupChain(chain string) error {
@@ -417,14 +351,10 @@ func (r *router) createAndSetupChain(chain string) error {
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
if chain == chainRTNAT {
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
@@ -442,39 +372,25 @@ func (r *router) insertEstablishedRule(chain string) error {
}
func (r *router) addJumpRules() error {
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
r.rules[ipv4Nat] = rule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
rule, found := r.rules[ipv4Nat]
if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
}
}
return nil
}
@@ -482,35 +398,19 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
@@ -518,41 +418,24 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("marking rule %s not found", ruleKey)
log.Debugf("nat rule %s not found", ruleKey)
}
return nil
}
func (r *router) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
if inverse {
intdir = "-o"
}
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {

View File

@@ -3,18 +3,18 @@
package iptables
import (
"fmt"
"context"
"net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
)
func isIptablesSupported() bool {
@@ -30,29 +30,18 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
_ = manager.Reset()
}()
// Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
@@ -60,15 +49,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.AddNatRule(pair)
require.NoError(t, err, "adding NAT rule should not return error")
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -78,71 +74,56 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
require.NoError(t, err, "forwarding pair should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "marking rule should be created")
foundRule, found := manager.rules[natRuleKey]
require.True(t, found, "marking rule should exist in the map")
require.Equal(t, markingRule, foundRule, "stored marking rule should match")
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "marking rule should not be created")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the map")
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
@@ -151,56 +132,45 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
assert.NoError(t, manager.Reset(), "shouldn't return error")
_ = manager.Reset()
}()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found := manager.rules[natRuleKey]
require.False(t, found, "marking rule should not exist in the manager map")
require.False(t, found, "nat rule should exist in the manager map")
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}
@@ -213,9 +183,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()

View File

@@ -1,16 +1,14 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) *ipList {
func newIpList(ip string) ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return &ipList{
return ipList{
ips: ips,
}
}
@@ -19,52 +17,27 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct {
ipsets map[string]*ipList
ipsets map[string]ipList // ipsetName -> ruleset
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]*ipList),
ipsets: make(map[string]ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName)
}
@@ -75,29 +48,3 @@ func (s *ipsetStore) ipsetNames() []string {
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -1,70 +0,0 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@@ -10,14 +10,11 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
)
@@ -55,8 +52,6 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
@@ -96,7 +91,7 @@ type Manager interface {
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
Reset() error
// Flush the changes to firewall controller
Flush() error
@@ -137,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
sortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
@@ -175,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// sortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
func sortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {

View File

@@ -5,18 +5,18 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/iface"
)
const (
@@ -27,8 +27,8 @@ const (
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
@@ -40,44 +40,54 @@ var (
)
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routeingFwChainName string
workTable *nftables.Table
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err)
return nil, err
}
return &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
m := &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routeingFwChainName: routeingFwChainName,
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
}, nil
}
}
func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
}
// AddPeerFiltering rule to the firewall
@@ -439,10 +449,22 @@ func (m *AclManager) createDefaultChains() (err error) {
return err
}
// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
@@ -450,96 +472,10 @@ func (m *AclManager) createDefaultChains() (err error) {
return fmt.Errorf(flushError, err)
}
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
func (m *AclManager) addJumpRulesToRtForward() {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -549,13 +485,13 @@ func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routingFwChainName,
Chain: m.routeingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
@@ -573,7 +509,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
return chain
}
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
@@ -605,6 +541,45 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},

View File

@@ -14,8 +14,6 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@@ -26,13 +24,6 @@ const (
chainNameInput = "INPUT"
)
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
@@ -44,70 +35,30 @@ type Manager struct {
}
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface)
workTable, err := m.createWorkTable()
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
return nil, err
}
m.router, err = newRouter(context, workTable, wgIface)
if err != nil {
return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
return nil, err
}
return m, nil
}
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
@@ -199,7 +150,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
chain = c
break
}
@@ -232,84 +183,68 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
oldLegacy := m.router.legacyManagement
if oldLegacy != isLegacy {
m.router.legacyManagement = isLegacy
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
if !isLegacy && oldLegacy {
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
return fmt.Errorf("list of chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
}
}
func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset forward rules: %v", err)
}
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
return nil
return m.rConn.Flush()
}
// Flush rule/chain/set operations from the buffer
@@ -351,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
&expr.Verdict{},
},
UserData: []byte(allowNetbirdInputRuleID),
}
@@ -382,33 +315,28 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: getEstablishedExprs(1),
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -1,11 +1,10 @@
package nftables
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os/exec"
"testing"
"time"
@@ -59,13 +58,12 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
defer func() {
err = manager.Reset(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
time.Sleep(time.Second)
}()
@@ -111,7 +109,6 @@ func TestNftablesManager(t *testing.T) {
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
@@ -171,7 +168,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil)
err = manager.Reset()
require.NoError(t, err, "failed to reset")
}
@@ -194,13 +191,12 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
defer func() {
if err := manager.Reset(nil); err != nil {
if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -227,105 +223,3 @@ func TestNFtablesCreatePerformance(t *testing.T) {
})
}
}
func runIptablesSave(t *testing.T) (string, string) {
t.Helper()
var stdout, stderr bytes.Buffer
cmd := exec.Command("iptables-save")
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
require.NoError(t, err, "iptables-save failed to run")
return stdout.String(), stderr.String()
}
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
t.Helper()
// Check for any incompatibility warnings
require.NotContains(t,
stderr,
"incompatible",
"iptables-save produced compatibility warning. Full stderr: %s",
stderr,
)
// Verify standard tables are present
expectedTables := []string{
"*filter",
"*nat",
"*mangle",
}
for _, table := range expectedTables {
require.Contains(t,
stdout,
table,
"iptables-save output missing expected table: %s\nFull stdout: %s",
table,
stdout,
)
}
}
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,
}
err = manager.AddNatRule(pair)
require.NoError(t, err, "failed to add NAT rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}

View File

@@ -2,6 +2,7 @@ package nftables
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
@@ -9,8 +10,6 @@ import (
"net/netip"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
@@ -21,12 +20,11 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingNat = "netbird-rt-nat"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
@@ -40,6 +38,8 @@ var (
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
@@ -52,8 +52,12 @@ type router struct {
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
@@ -72,25 +76,20 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
return nil, err
}
}
return r, nil
}
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return nil
return r, err
}
// Reset cleans existing nftables default forward rules from the system
@@ -98,7 +97,40 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.removeAcceptForwardRules()
return r.cleanUpDefaultForwardRules()
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
return r.conn.Flush()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -117,6 +149,7 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
@@ -124,42 +157,25 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
r.acceptForwardRules()
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := r.refreshRulesMap(); err != nil {
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
@@ -172,7 +188,6 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
@@ -233,18 +248,9 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey),
}
rule = r.conn.AddRule(rule)
r.rules[string(ruleKey)] = r.conn.AddRule(rule)
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
return ruleKey, r.conn.Flush()
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
@@ -282,10 +288,6 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -437,149 +439,44 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
op := expr.CmpOpEq
dir := expr.MetaKeyIIFNAME
if pair.Inverse {
op = expr.CmpOpNeq
dir = expr.MetaKeyOIFNAME
}
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Key: dir,
Register: 1,
},
&expr.Cmp{
Op: op,
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
Data: intf,
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
&expr.Counter{}, &expr.Masq{},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
return fmt.Errorf("remove routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNamePrerouting],
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
@@ -656,10 +553,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
@@ -668,60 +562,19 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() error {
func (r *router) acceptForwardRules() {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
return
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
@@ -741,15 +594,6 @@ func (r *router) acceptForwardRulesNftables() error {
}
r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
@@ -760,86 +604,50 @@ func (r *router) acceptForwardRulesNftables() error {
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...),
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 2,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes the prerouting mark rule
// RemoveNatRule removes a nftables rule pair from nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -850,24 +658,25 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
log.Debugf("nftables: nat rule %s not found", ruleKey)
}
return nil

View File

@@ -3,6 +3,7 @@
package nftables
import (
"context"
"encoding/binary"
"net/netip"
"os/exec"
@@ -10,7 +11,6 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -33,87 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
})
defer func(manager *router, pair firewall.RouterPair) {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
}(manager, testCase.InputPair)
if testCase.InputPair.Masquerade {
// Build expected expressions for connection tracking
conntrackExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
}
)
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
@@ -123,66 +123,67 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.router
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.True(t, found, "NAT rule should exist before removal")
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
})
}
}
@@ -197,9 +198,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(context.Background(), workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
@@ -314,10 +314,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
@@ -350,6 +346,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
})
}
}
@@ -364,9 +364,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(context.Background(), workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")

View File

@@ -1,47 +0,0 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@@ -2,36 +2,16 @@
package uspfilter
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
return m.nativeFirewall.Reset()
}
return nil
}

View File

@@ -6,9 +6,6 @@ import (
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type action string
@@ -20,28 +17,13 @@ const (
)
// Reset firewall to the default state
func (m *Manager) Reset(*statemanager.Manager) error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if !isWindowsFirewallReachable() {
return nil
}

View File

@@ -1,137 +0,0 @@
// common.go
package conntrack
import (
"net"
"sync"
"sync/atomic"
"time"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
}
// these small methods will be inlined by the compiler
// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}
// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}
// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}

View File

@@ -1,115 +0,0 @@
package conntrack
import (
"net"
"testing"
)
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})
}

View File

@@ -1,170 +0,0 @@
package conntrack
import (
"net"
"sync"
"time"
"github.com/google/gopacket/layers"
)
const (
// DefaultICMPTimeout is the default timeout for ICMP connections
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
)
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
Sequence uint16
ID uint16
}
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
tracker := &ICMPTracker{
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
}
t.mutex.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
}
func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
return ICMPConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
ID: id,
Sequence: seq,
}
}

View File

@@ -1,39 +0,0 @@
package conntrack
import (
"net"
"testing"
)
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
}
})
}

View File

@@ -1,352 +0,0 @@
package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"net"
"sync"
"time"
)
const (
// MSL (Maximum Segment Lifetime) is typically 2 minutes
MSL = 2 * time.Minute
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
TimeWaitTimeout = 2 * MSL
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPUrg uint8 = 0x20
)
const (
// DefaultTCPTimeout is the default timeout for established TCP connections
DefaultTCPTimeout = 3 * time.Hour
// TCPHandshakeTimeout is timeout for TCP handshake completion
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
)
// TCPState represents the state of a TCP connection
type TCPState int
const (
TCPStateNew TCPState = iota
TCPStateSynSent
TCPStateSynReceived
TCPStateEstablished
TCPStateFinWait1
TCPStateFinWait2
TCPStateClosing
TCPStateTimeWait
TCPStateCloseWait
TCPStateLastAck
TCPStateClosed
)
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
State TCPState
sync.RWMutex
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
done chan struct{}
timeout time.Duration
ipPool *PreallocatedIPs
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
// Handle RST packets
if flags&TCPRst != 0 {
conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true
}
conn.Unlock()
return false
}
conn.Lock()
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
return isEstablished || isValidState
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
return
}
switch conn.State {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if isOutbound {
conn.State = TCPStateFinWait1
} else {
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
}
case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
// Keep established = false from previous state
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
conn.State = TCPStateLastAck
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
}
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
switch state {
case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent:
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0
case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateFinWait2:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateClosing:
// In CLOSING state, we should accept the final ACK
return flags&TCPAck != 0
case TCPStateTimeWait:
// In TIME_WAIT, we might see retransmissions
return flags&TCPAck != 0
case TCPStateCloseWait:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
}
func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *TCPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.IsEstablished():
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}
lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
func isValidFlagCombination(flags uint8) bool {
// Invalid: SYN+FIN
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
return false
}
// Invalid: RST with SYN or FIN
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
return false
}
return true
}

View File

@@ -1,308 +0,0 @@
package conntrack
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Security Tests", func(t *testing.T) {
tests := []struct {
name string
flags uint8
wantDrop bool
desc string
}{
{
name: "Block unsolicited SYN-ACK",
flags: TCPSyn | TCPAck,
wantDrop: true,
desc: "Should block SYN-ACK without prior SYN",
},
{
name: "Block invalid SYN-FIN",
flags: TCPSyn | TCPFin,
wantDrop: true,
desc: "Should block invalid SYN-FIN combination",
},
{
name: "Block unsolicited RST",
flags: TCPRst,
wantDrop: true,
desc: "Should block RST without connection",
},
{
name: "Block unsolicited ACK",
flags: TCPAck,
wantDrop: true,
desc: "Should block ACK without connection",
},
{
name: "Block data without connection",
flags: TCPAck | TCPPush,
wantDrop: true,
desc: "Should block data without established connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
})
t.Run("Connection Flow Tests", func(t *testing.T) {
tests := []struct {
name string
test func(*testing.T)
desc string
}{
{
name: "Normal Handshake",
test: func(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake")
},
},
{
name: "Normal Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
},
{
name: "RST During Connection",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{
name: "Simultaneous Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout)
tt.test(t)
})
}
})
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
tests := []struct {
name string
setupState func()
sendRST func()
wantValid bool
desc string
}{
{
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: true,
desc: "Should accept RST for established connection",
},
{
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: false,
desc: "Should reject RST without connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupState()
tt.sendRST()
// Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
}
})
}
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -1,158 +0,0 @@
package conntrack
import (
"net"
"sync"
"time"
)
const (
// DefaultUDPTimeout is the default timeout for UDP connections
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
)
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
}
// UDPTracker manages UDP connection states
type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
}
t.mutex.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *UDPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
if !exists {
return nil, false
}
return conn, true
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}

View File

@@ -1,243 +0,0 @@
package conntrack
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewUDPTracker(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantTimeout time.Duration
}{
{
name: "with custom timeout",
timeout: 1 * time.Minute,
wantTimeout: 1 * time.Minute,
},
{
name: "with zero timeout uses default",
timeout: 0,
wantTimeout: DefaultUDPTimeout,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done)
})
}
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct {
name string
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
sleep time.Duration
want bool
}{
{
name: "valid inbound response",
srcIP: dstIP, // Original destination is now source
dstIP: srcIP, // Original source is now destination
srcPort: dstPort, // Original destination port is now source
dstPort: srcPort, // Original source port is now destination
sleep: 0,
want: true,
},
{
name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid source port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: 54321,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: 54321,
sleep: 0,
want: false,
},
{
name: "expired connection",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 2 * time.Second, // Longer than tracker timeout
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got)
})
}
}
func TestUDPTracker_Cleanup(t *testing.T) {
// Use shorter intervals for testing
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
// Start cleanup routine
go tracker.cleanupRoutine()
// Add some connections
connections := []struct {
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
}{
{
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
}
// Verify initial connections
assert.Len(t, tracker.connections, 2)
// Wait for connection timeout and cleanup interval
time.Sleep(timeout + 2*cleanupInterval)
tracker.mutex.RLock()
connCount := len(tracker.connections)
tracker.mutex.RUnlock()
// Verify connections were cleaned up
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
// Properly close the tracker
tracker.Close()
}
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
}
})
}

View File

@@ -4,8 +4,6 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
"sync"
"github.com/google/gopacket"
@@ -14,16 +12,12 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const layerTypeAll = 0
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
)
@@ -47,11 +41,6 @@ type Manager struct {
nativeFirewall firewall.Manager
mutex sync.RWMutex
stateful bool
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
}
// decoder for packages
@@ -83,8 +72,6 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
}
func create(iface IFaceMapper) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
m := &Manager{
decoders: sync.Pool{
New: func() any {
@@ -102,16 +89,6 @@ func create(iface IFaceMapper) (*Manager, error) {
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
stateful: !disableConntrack,
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if err := iface.SetFilter(m); err != nil {
@@ -120,10 +97,6 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil
}
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil {
return false
@@ -217,7 +190,7 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil
}
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errRouteNotSupported
}
@@ -259,11 +232,8 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
}
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
func (m *Manager) SetLegacyManagement(_ bool) error {
return nil
}
// Flush doesn't need to be implemented for this manager
@@ -271,230 +241,78 @@ func (m *Manager) Flush() error { return nil }
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData)
return m.dropFilter(packetData, m.outgoingRules, false)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules)
return m.dropFilter(packetData, m.incomingRules, true)
}
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
// dropFilter implements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
return false
}
if len(d.decoded) < 2 {
return false
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
return false
}
// Always process UDP hooks
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.trackUDPOutbound(d, srcIP, dstIP)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
// Track other protocols only if stateful mode is enabled
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
}
}
return false
}
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP
case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP
default:
return nil, nil
}
}
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
flags |= conntrack.TCPSyn
}
if tcp.ACK {
flags |= conntrack.TCPAck
}
if tcp.FIN {
flags |= conntrack.TCPFin
}
if tcp.RST {
flags |= conntrack.TCPRst
}
if tcp.PSH {
flags |= conntrack.TCPPush
}
if tcp.URG {
flags |= conntrack.TCPUrg
}
return flags
}
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
m.udpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) {
return rule.udpHook(packetData)
}
}
}
}
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
)
}
}
// dropFilter implements filtering logic for incoming packets
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) {
return true
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
if !m.isWireguardTraffic(srcIP, dstIP) {
return false
}
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false
}
return m.applyRules(srcIP, packetData, rules, d)
}
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err)
return false
return true
}
if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet")
return false
}
return true
}
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
)
case layers.LayerTypeUDP:
return m.udpTracker.IsValidInbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
case layers.LayerTypeICMPv4:
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
)
// TODO: ICMPv6
return true
}
return false
}
ipLayer := d.decoded[0]
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
switch ipLayer {
case layers.LayerTypeIPv4:
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
return false
}
case layers.LayerTypeIPv6:
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
return false
}
default:
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
}
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
return filter
}
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
return filter
}
// Default policy: DROP ALL
// default policy is DROP ALL
return true
}

View File

@@ -1,998 +0,0 @@
package uspfilter
import (
"fmt"
"math/rand"
"net"
"os"
"strings"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface/device"
)
// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range
func generateRandomIPs(n int) []net.IP {
ips := make([]net.IP, n)
seen := make(map[string]bool)
for i := 0; i < n; {
ip := make(net.IP, 4)
ip[0] = 100
ip[1] = byte(64 + rand.Intn(63)) // 64-126
ip[2] = byte(rand.Intn(256))
ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255
key := ip.String()
if !seen[key] {
ips[i] = ip
seen[key] = true
i++
}
}
return ips
}
func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: protocol,
}
var transportLayer gopacket.SerializableLayer
switch protocol {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(b, err)
return buf.Bytes()
}
// BenchmarkCoreFiltering focuses on the essential performance comparisons between
// stateful and stateless filtering approaches
func BenchmarkCoreFiltering(b *testing.B) {
scenarios := []struct {
name string
stateful bool
setupFunc func(*Manager)
desc string
}{
{
name: "stateless_single_allow_all",
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
},
{
name: "stateful_no_rules",
stateful: true,
setupFunc: func(m *Manager) {
// No explicit rules - rely purely on connection tracking
},
desc: "Pure connection tracking without any rules",
},
{
name: "stateless_explicit_return",
stateful: false,
setupFunc: func(m *Manager) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
require.NoError(b, err)
}
},
desc: "Explicit rules matching return traffic patterns without state",
},
{
name: "stateful_with_established",
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
},
}
// Test both TCP and UDP
protocols := []struct {
name string
proto layers.IPProtocol
}{
{"TCP", layers.IPProtocolTCP},
{"UDP", layers.IPProtocolUDP},
}
for _, sc := range scenarios {
for _, proto := range protocols {
b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1"))
} else {
require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m"))
}
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup
sc.setupFunc(manager)
// Generate test packets
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
srcPort := uint16(1024 + b.N%60000)
dstPort := uint16(80)
outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto)
inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto)
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
}
})
}
}
}
// BenchmarkStateScaling measures how performance scales with connection table size
func BenchmarkStateScaling(b *testing.B) {
connCounts := []int{100, 1000, 10000, 100000}
for _, count := range connCounts {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table
srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count)
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound)
}
// Test packet
testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP)
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, manager.incomingRules)
}
})
}
}
// BenchmarkEstablishmentOverhead measures the overhead of connection establishment
func BenchmarkEstablishmentOverhead(b *testing.B) {
scenarios := []struct {
name string
established bool
}{
{"established", true},
{"new", false},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
}
})
}
}
// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic
func BenchmarkRoutedNetworkReturn(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
state string // "new", "established", "post_handshake" (TCP only)
setupFunc func(*Manager)
genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario
desc string
}{
{
name: "allow_non_wg_tcp_new",
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
},
desc: "Allow non-WG: TCP new connection",
},
{
name: "allow_non_wg_tcp_established",
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
// Generate packets with ACK flag for established connection
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
},
desc: "Allow non-WG: TCP established connection",
},
{
name: "allow_non_wg_udp_new",
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Allow non-WG: UDP new connection",
},
{
name: "allow_non_wg_udp_established",
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Allow non-WG: UDP established connection",
},
{
name: "stateful_tcp_new",
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
},
desc: "Stateful: TCP new connection",
},
{
name: "stateful_tcp_established",
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
// Generate established TCP packets (ACK flag)
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
},
desc: "Stateful: TCP established connection",
},
{
name: "stateful_tcp_post_handshake",
proto: layers.IPProtocolTCP,
state: "post_handshake",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
// Generate packets with PSH+ACK flags for data transfer
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
desc: "Stateful: TCP post-handshake data transfer",
},
{
name: "stateful_udp_new",
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Stateful: UDP new connection",
},
{
name: "stateful_udp_established",
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Stateful: UDP established connection",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
// Setup scenario
sc.setupFunc(manager)
// Use IPs outside WG range for routed network simulation
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("8.8.8.8")
outbound, inbound := sc.genPackets(srcIP, dstIP)
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
}
})
}
}
var scenarios = []struct {
name string
stateful bool // Whether conntrack is enabled
rules bool // Whether to add return traffic rules
routed bool // Whether to test routed network traffic
connCount int // Number of concurrent connections
desc string
}{
{
name: "stateless_with_rules_100conns",
stateful: false,
rules: true,
routed: false,
connCount: 100,
desc: "Pure stateless with return traffic rules, 100 conns",
},
{
name: "stateless_with_rules_1000conns",
stateful: false,
rules: true,
routed: false,
connCount: 1000,
desc: "Pure stateless with return traffic rules, 1000 conns",
},
{
name: "stateful_no_rules_100conns",
stateful: true,
rules: false,
routed: false,
connCount: 100,
desc: "Pure stateful tracking without rules, 100 conns",
},
{
name: "stateful_no_rules_1000conns",
stateful: true,
rules: false,
routed: false,
connCount: 1000,
desc: "Pure stateful tracking without rules, 1000 conns",
},
{
name: "stateful_with_rules_100conns",
stateful: true,
rules: true,
routed: false,
connCount: 100,
desc: "Combined stateful + rules (current implementation), 100 conns",
},
{
name: "stateful_with_rules_1000conns",
stateful: true,
rules: true,
routed: false,
connCount: 1000,
desc: "Combined stateful + rules (current implementation), 1000 conns",
},
{
name: "routed_network_100conns",
stateful: true,
rules: false,
routed: true,
connCount: 100,
desc: "Routed network traffic (non-WG), 100 conns",
},
{
name: "routed_network_1000conns",
stateful: true,
rules: false,
routed: true,
connCount: 1000,
desc: "Routed network traffic (non-WG), 1000 conns",
},
}
// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns
func BenchmarkLongLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs for connections
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
// Create established connections
for i := 0; i < sc.connCount; i++ {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
}
// Prepare test packets simulating bidirectional traffic
inPackets := make([][]byte, sc.connCount)
outPackets := make([][]byte, sc.connCount)
for i := 0; i < sc.connCount; i++ {
// Server -> Client (inbound)
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
// Client -> Server (outbound)
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
connIdx := i % sc.connCount
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
}
})
}
}
// BenchmarkShortLivedConnections tests performance with many short-lived connections
func BenchmarkShortLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs for connections
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
// Create packet patterns for a complete HTTP-like short connection:
// 1. Initial handshake (SYN, SYN-ACK, ACK)
// 2. HTTP Request (PSH+ACK from client)
// 3. HTTP Response (PSH+ACK from server)
// 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK)
type connPackets struct {
syn []byte
synAck []byte
ack []byte
request []byte
response []byte
finClient []byte
ackServer []byte
finServer []byte
ackClient []byte
}
// Generate all possible connection patterns
patterns := make([]connPackets, sc.connCount)
for i := 0; i < sc.connCount; i++ {
patterns[i] = connPackets{
// Handshake
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
// Data transfer
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
// Connection teardown
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPAck)),
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Each iteration creates a new short-lived connection
connIdx := i % sc.connCount
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack)
// Data transfer
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules)
// Connection teardown
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient)
}
})
}
}
// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel
func BenchmarkParallelLongLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs for connections
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
// Create established connections
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
}
// Pre-generate test packets
inPackets := make([][]byte, sc.connCount)
outPackets := make([][]byte, sc.connCount)
for i := 0; i < sc.connCount; i++ {
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
// Each goroutine gets its own counter to distribute load
counter := 0
for pb.Next() {
connIdx := counter % sc.connCount
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
}
})
})
}
}
// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel
func BenchmarkParallelShortLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs and pre-generate all packet patterns
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
type connPackets struct {
syn []byte
synAck []byte
ack []byte
request []byte
response []byte
finClient []byte
ackServer []byte
finServer []byte
ackClient []byte
}
patterns := make([]connPackets, sc.connCount)
for i := 0; i < sc.connCount; i++ {
patterns[i] = connPackets{
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPAck)),
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
}
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
counter := 0
for pb.Next() {
connIdx := counter % sc.connCount
counter++
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient)
}
})
})
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}

View File

@@ -3,7 +3,6 @@ package uspfilter
import (
"fmt"
"net"
"sync"
"testing"
"time"
@@ -12,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -187,10 +185,10 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager := &Manager{
incomingRules: map[string]RuleSet{},
outgoingRules: map[string]RuleSet{},
}
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -261,7 +259,7 @@ func TestManagerReset(t *testing.T) {
return
}
err = m.Reset(nil)
err = m.Reset()
if err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
@@ -315,7 +313,7 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to set network layer for checksum: %v", err)
return
}
payload := gopacket.Payload("test")
payload := gopacket.Payload([]byte("test"))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
@@ -327,12 +325,12 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), m.outgoingRules) {
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
t.Errorf("expected packet to be accepted")
return
}
if err = m.Reset(nil); err != nil {
if err = m.Reset(); err != nil {
t.Errorf("failed to reset Manager: %v", err)
return
}
@@ -350,9 +348,6 @@ func TestRemovePacketHook(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
defer func() {
require.NoError(t, manager.Reset(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
@@ -389,88 +384,6 @@ func TestRemovePacketHook(t *testing.T) {
}
}
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: net.ParseIP("100.10.0.1"),
DstIP: net.ParseIP("100.10.0.100"),
Protocol: layers.IPProtocolUDP,
}
udp := &layers.UDP{
SrcPort: 51334,
DstPort: 53,
}
err = udp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
payload := gopacket.Payload("test")
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)
// Test non-UDP packet is ignored
ipv4.Protocol = layers.IPProtocolTCP
buf = gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}
func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
@@ -483,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
defer func() {
if err := manager.Reset(nil); err != nil {
if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err)
}
time.Sleep(time.Second)
@@ -505,213 +418,3 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
})
}
}
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
}
defer func() {
require.NoError(t, manager.Reset(nil))
}()
// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
// Create outbound packet
outboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolUDP,
}
outboundUDP := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
require.NoError(t, err)
outboundBuf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(outboundBuf, opts,
outboundIPv4,
outboundUDP,
gopacket.Payload("test"),
)
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
// Create valid inbound response packet
inboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: dstIP, // Original destination is now source
DstIP: srcIP, // Original source is now destination
Protocol: layers.IPProtocolUDP,
}
inboundUDP := &layers.UDP{
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
}
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
require.NoError(t, err)
inboundBuf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(inboundBuf, opts,
inboundIPv4,
inboundUDP,
gopacket.Payload("response"),
)
require.NoError(t, err)
// Test roundtrip response handling over time
checkPoints := []struct {
sleep time.Duration
shouldAllow bool
description string
}{
{
sleep: 0,
shouldAllow: true,
description: "Immediate response should be allowed",
},
{
sleep: 50 * time.Millisecond,
shouldAllow: true,
description: "Response within timeout should be allowed",
},
{
sleep: 100 * time.Millisecond,
shouldAllow: true,
description: "Response at half timeout should be allowed",
},
{
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
sleep: 250 * time.Millisecond,
shouldAllow: false,
description: "Response after timeout should be dropped",
},
}
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")
}
}
// Test invalid response packets (while connection is expired)
invalidCases := []struct {
name string
modifyFunc func(*layers.IPv4, *layers.UDP)
description string
}{
{
name: "wrong source IP",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
ip.SrcIP = net.ParseIP("100.10.0.101")
},
description: "Response from wrong IP should be dropped",
},
{
name: "wrong destination IP",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
ip.DstIP = net.ParseIP("100.10.0.2")
},
description: "Response to wrong IP should be dropped",
},
{
name: "wrong source port",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
udp.SrcPort = 54
},
description: "Response from wrong port should be dropped",
},
{
name: "wrong destination port",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
udp.DstPort = 51335
},
description: "Response to wrong port should be dropped",
},
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
t.Run(tc.name, func(t *testing.T) {
testIPv4 := *inboundIPv4
testUDP := *inboundUDP
tc.modifyFunc(&testIPv4, &testUDP)
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
require.NoError(t, err)
testBuf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(testBuf, opts,
&testIPv4,
&testUDP,
gopacket.Payload("response"),
)
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
require.True(t, drop, tc.description)
})
}
}

142
client/iface/bind/bind.go Normal file
View File

@@ -0,0 +1,142 @@
package bind
import (
"fmt"
"net"
"runtime"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
muUDPMux sync.Mutex
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
filterFn FilterFn
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -1,5 +0,0 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint

View File

@@ -1,303 +0,0 @@
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
// 1. filter out STUN messages and handle them
// 2. forward the received packets to the WireGuard interface from the relayed connection
//
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}
func (s *ICEBind) Close() error {
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
return s.StdNetBind.Close()
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}
for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -162,13 +162,12 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}

View File

@@ -5,6 +5,7 @@ package device
import (
"strings"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
@@ -30,13 +31,13 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
iceBind: bind.NewICEBind(transportNet, filterFn),
tunAdapter: tunAdapter,
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -28,14 +29,14 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}

View File

@@ -6,6 +6,7 @@ package device
import (
"os"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device"
@@ -29,13 +30,13 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
iceBind: iceBind,
iceBind: bind.NewICEBind(transportNet, filterFn),
tunFd: tunFd,
}
}

View File

@@ -6,6 +6,7 @@ package device
import (
"fmt"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
@@ -30,7 +31,7 @@ type TunNetstackDevice struct {
configurer WGConfigurer
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -38,7 +39,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: iceBind,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"runtime"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -29,7 +30,7 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
@@ -40,8 +41,7 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
}
iceBind: bind.NewICEBind(transportNet, filterFn)}
}
func (t *USPDevice) Create() (WGConfigurer, error) {

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net/netip"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
@@ -31,14 +32,14 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}

View File

@@ -27,14 +27,14 @@ import (
type status int
const (
unknown status = 1
unloaded status = 2
unloading status = 3
loading status = 4
live status = 5
inuse status = 6
defaultModuleDir = "/lib/modules"
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
defaultModuleDir = "/lib/modules"
unknown status = iota
unloaded
unloading
loading
live
inuse
envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED"
)
type module struct {

View File

@@ -6,16 +6,12 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
const (
@@ -26,35 +22,14 @@ const (
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
}
type WGIFaceOpts struct {
IFaceName string
Address string
WGPort int
WGPrivKey string
MTU int
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
}
// WGIface represents an interface instance
type WGIface struct {
tun WGTunDevice
userspaceBind bool
mu sync.Mutex
configurer device.WGConfigurer
filter device.PacketFilter
wgProxyFactory wgProxyFactory
}
func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
configurer device.WGConfigurer
filter device.PacketFilter
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@@ -149,26 +124,22 @@ func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
err := w.tun.Close()
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
}
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if err := w.waitUntilRemoved(); err != nil {
err = w.waitUntilRemoved()
if err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
if err := w.Destroy(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
return errors.FormatErrorOrNil(result)
err = w.Destroy()
if err != nil {
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
}
log.Infof("interface %s successfully removed", w.Name())
}
return errors.FormatErrorOrNil(result)
return nil
}
// SetFilter sets packet filters for the userspace implementation

View File

@@ -0,0 +1,43 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -2,8 +2,6 @@
package iface
import "fmt"
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
@@ -19,8 +17,3 @@ func (w *WGIface) Create() error {
w.configurer = cfgr
return nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}

View File

@@ -1,24 +0,0 @@
package iface
import (
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -7,8 +7,39 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
@@ -34,8 +65,3 @@ func (w *WGIface) Create() error {
return backoff.Retry(operation, backOff)
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -1,10 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

31
client/iface/iface_ios.go Normal file
View File

@@ -0,0 +1,31 @@
//go:build ios
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -1,24 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,34 +0,0 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,26 +0,0 @@
//go:build ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,45 +0,0 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -1,32 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -45,16 +45,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: addr,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -127,16 +118,7 @@ func Test_CreateInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -171,16 +153,7 @@ func Test_Close(t *testing.T) {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -216,16 +189,7 @@ func TestRecreation(t *testing.T) {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -288,15 +252,7 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -344,16 +300,7 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -414,16 +361,7 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err)
}
opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -480,15 +418,7 @@ func Test_ConnectPeers(t *testing.T) {
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
Address: peer1wgIP,
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
TransportNet: newNet,
}
iface1, err := NewWGIFace(optsPeer1)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -502,12 +432,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
localIP, err := getLocalIP()
if err != nil {
t.Fatal(err)
}
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort))
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort))
if err != nil {
t.Fatal(err)
}
@@ -519,17 +444,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil {
t.Fatal(err)
}
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
Address: peer2wgIP,
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,
TransportNet: newNet,
}
iface2, err := NewWGIFace(optsPeer2)
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -543,7 +458,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort))
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort))
if err != nil {
t.Fatal(err)
}
@@ -612,28 +527,3 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
func getLocalIP() (string, error) {
// Get all interfaces
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
if ipNet.IP.To4() == nil {
continue
}
return ipNet.IP.String(), nil
}
return "", fmt.Errorf("no local IP found")
}

View File

@@ -0,0 +1,49 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"runtime"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
}

View File

@@ -0,0 +1,41 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@@ -1,141 +0,0 @@
package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
)
type ProxyBind struct {
Bind *bind.ICEBind
wgAddr *net.UDPAddr
wgEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
if err != nil {
return err
}
p.wgAddr = addr
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return err
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyBind) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *ProxyBind) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
if p.closed {
return nil
}
p.closed = true
p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr)
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
}()
for {
buf := make([]byte, 1500)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
msg := bind.RecvMessage{
Endpoint: p.wgEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
}
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}

View File

@@ -1,49 +0,0 @@
//go:build linux && !android
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
type KernelFactory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewKernelFactory(wgPort int) *KernelFactory {
f := &KernelFactory{
wgPort: wgPort,
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
f.ebpfProxy = ebpfProxy
return f
}
func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -1,29 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -1,30 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind *bind.ICEBind
}
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: iceBind,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{
Bind: w.bind,
}
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -1,15 +0,0 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
}

View File

@@ -1,56 +0,0 @@
//go:build linux && !android
package wgproxy
import (
"context"
"os"
"testing"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,11 +1,8 @@
package id
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
@@ -24,41 +21,5 @@ func GenerateRouteRuleKey(
dPort *manager.Port,
action manager.Action,
) RuleID {
manager.SortPrefixes(sources)
h := sha256.New()
// Write all fields to the hasher, with delimiters
h.Write([]byte("sources:"))
for _, src := range sources {
h.Write([]byte(src.String()))
h.Write([]byte(","))
}
h.Write([]byte("destination:"))
h.Write([]byte(destination.String()))
h.Write([]byte("proto:"))
h.Write([]byte(proto))
h.Write([]byte("sPort:"))
if sPort != nil {
h.Write([]byte(sPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("dPort:"))
if dPort != nil {
h.Write([]byte(dPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("action:"))
h.Write([]byte(strconv.Itoa(int(action))))
hash := hex.EncodeToString(h.Sum(nil))
// prepend destination prefix to be able to identify the rule
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
}

View File

@@ -3,7 +3,6 @@ package acl
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
@@ -11,18 +10,14 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
@@ -172,40 +167,31 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
var newRouteRules = make(map[id.RuleID]struct{})
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
continue
return fmt.Errorf("apply route ACL: %w", err)
}
newRouteRules[id] = struct{}{}
}
// Clean up old firewall rules
for id := range d.routeRules {
if _, exists := newRouteRules[id]; !exists {
if _, ok := newRouteRules[id]; !ok {
if err := d.firewall.DeleteRouteRule(id); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
log.Errorf("failed to delete route firewall rule: %v", err)
continue
}
// implicitly deleted from the map
delete(d.routeRules, id)
}
}
d.routeRules = newRouteRules
return nberrors.FormatErrorOrNil(merr)
return nil
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty
return "", fmt.Errorf("source ranges is empty")
}
var sources []netip.Prefix

View File

@@ -1,6 +1,7 @@
package acl
import (
"context"
"net"
"testing"
@@ -51,13 +52,13 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil)
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Reset(nil)
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
@@ -344,13 +345,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil)
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
_ = fw.Reset(nil)
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)

View File

@@ -46,7 +46,6 @@ type ConfigInput struct {
ManagementURL string
AdminURL string
ConfigPath string
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
NATExternalIPs []string
@@ -106,10 +105,10 @@ type Config struct {
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
// Path to a certificate used for mTLS authentication
//Path to a certificate used for mTLS authentication
ClientCertPath string
// Path to corresponding private key of ClientCertPath
//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
@@ -117,7 +116,7 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if fileExists(configPath) {
if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
@@ -150,7 +149,7 @@ func ReadConfig(configPath string) (*Config, error) {
// UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
if !configFileIsExists(input.ConfigPath) {
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
}
@@ -159,13 +158,13 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
if !configFileIsExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
return nil, err
}
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
@@ -186,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
return util.WriteJson(path, config)
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -216,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
}
if updated {
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err
}
}
@@ -473,19 +472,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
return false
}
func fileExists(path string) bool {
func configFileIsExists(path string) bool {
_, err := os.Stat(path)
return !os.IsNotExist(err)
}
func createFile(path string) error {
file, err := os.Create(path)
if err != nil {
return err
}
return file.Close()
}
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version.

View File

@@ -40,8 +40,6 @@ type ConnectClient struct {
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
persistNetworkMap bool
}
func NewConnectClient(
@@ -64,7 +62,10 @@ func (c *ConnectClient) Run() error {
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
func (c *ConnectClient) RunWithProbes(
probes *ProbeHolder,
runningChan chan error,
) error {
return c.run(MobileDependency{}, probes, runningChan)
}
@@ -91,7 +92,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
@@ -100,12 +100,15 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
func (c *ConnectClient) run(
mobileDependency MobileDependency,
probes *ProbeHolder,
runningChan chan error,
) error {
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
@@ -114,6 +117,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -161,8 +170,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx)
defer func() {
_, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.MarkManagementDisconnected(state.err)
c.statusRecorder.CleanLocalPeerState()
cancel()
}()
@@ -212,8 +220,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.statusRecorder.MarkSignalDisconnected(nil)
defer func() {
_, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
c.statusRecorder.MarkSignalDisconnected(state.err)
}()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@@ -236,7 +243,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 {
if token != nil {
if err := relayManager.UpdateToken(token); err != nil {
@@ -247,7 +253,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
}
c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig()
@@ -262,7 +270,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil {
@@ -340,19 +348,6 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
return StatusIdle
}
status, err := CtxGetState(c.ctx).Status()
if err != nil {
return StatusIdle
}
return status
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
@@ -363,11 +358,7 @@ func (c *ConnectClient) Stop() error {
if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
return c.engine.Stop()
}
func (c *ConnectClient) isContextCancelled() bool {
@@ -379,22 +370,6 @@ func (c *ConnectClient) isContextCancelled() bool {
}
}
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored
// network map will be cleared. This functionality is primarily used for debugging
// and should not be enabled during normal operation.
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
c.engineMutex.Lock()
c.persistNetworkMap = enabled
c.engineMutex.Unlock()
engine := c.Engine()
if engine != nil {
engine.SetNetworkMapPersistence(enabled)
}
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false

View File

@@ -1,5 +1,6 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

View File

@@ -3,5 +3,6 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
)

View File

@@ -9,8 +9,6 @@ import (
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
var (
@@ -22,7 +20,7 @@ var (
}
)
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
type repairConfFn func([]string, string, *resolvConf) error
type repair struct {
operationFile string
@@ -42,7 +40,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
}
}
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
if f.inotify != nil {
return
}
@@ -83,7 +81,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
}
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
if err != nil {
log.Errorf("failed to repair resolv.conf: %v", err)
}

View File

@@ -9,7 +9,6 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/util"
)
@@ -105,14 +104,14 @@ nameserver 8.8.8.8`,
var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
updateFn := func([]string, string, *resolvConf) error {
changed = true
cancel()
return nil
}
r := newRepair(operationFile, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
if err != nil {
@@ -152,14 +151,14 @@ searchdomain netbird.cloud something`
var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
updateFn := func([]string, string, *resolvConf) error {
changed = true
cancel()
return nil
}
r := newRepair(tmpLink, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
if err != nil {

View File

@@ -11,8 +11,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@@ -38,7 +36,7 @@ type fileConfigurator struct {
nbNameserverIP string
}
func newFileConfigurator() (*fileConfigurator, error) {
func newFileConfigurator() (hostManager, error) {
fc := &fileConfigurator{}
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
return fc, nil
@@ -48,7 +46,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false
}
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
backupFileExist := f.isBackupFileExist()
if !config.RouteAll {
if backupFileExist {
@@ -78,15 +76,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
if err != nil {
return err
}
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
return nil
}
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg)
@@ -109,7 +107,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
}
@@ -147,6 +145,10 @@ func (f *fileConfigurator) restore() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}
@@ -174,7 +176,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile()
}
log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
return nil
}
@@ -190,6 +192,10 @@ func restoreResolvConfFile() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
}
return nil
}

View File

@@ -1,226 +0,0 @@
package dns
import (
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
)
type SubdomainMatcher interface {
dns.Handler
MatchSubdomains() bool
}
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool
}
// HandlerChain represents a prioritized chain of DNS handlers
type HandlerChain struct {
mu sync.RWMutex
handlers []HandlerEntry
}
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
shouldContinue bool
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
w.shouldContinue = true
return nil
}
return w.ResponseWriter.WriteMsg(m)
}
func NewHandlerChain() *HandlerChain {
return &HandlerChain{
handlers: make([]HandlerEntry, 0),
}
}
// GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string {
return w.origPattern
}
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
c.mu.Lock()
defer c.mu.Unlock()
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
// Check if handler implements SubdomainMatcher interface
matchSubdomains := false
if matcher, ok := handler.(SubdomainMatcher); ok {
matchSubdomains = matcher.MatchSubdomains()
}
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
pattern, origPattern, isWildcard, matchSubdomains, priority)
entry := HandlerEntry{
Handler: handler,
Priority: priority,
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains,
}
// Insert handler in priority order
pos := 0
for i, h := range c.handlers {
if h.Priority < priority {
pos = i
break
}
pos = i + 1
}
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}
// RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
}
}
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern)
for _, entry := range c.handlers {
if entry.Pattern == pattern {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := r.Question[0].Name
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
}
}
// Try handlers in priority order
for _, entry := range handlers {
var matched bool
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = qname == entry.Pattern
}
}
if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
continue
}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}

View File

@@ -1,511 +0,0 @@
package dns_test
import (
"net"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
)
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create mock handlers for different priorities
defaultHandler := &nbdns.MockHandler{}
matchDomainHandler := &nbdns.MockHandler{}
dnsRouteHandler := &nbdns.MockHandler{}
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
// Execute
chain.ServeDNS(w, r)
// Verify all expectations were met
dnsRouteHandler.AssertExpectations(t)
matchDomainHandler.AssertExpectations(t)
defaultHandler.AssertExpectations(t)
}
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct {
name string
handlerDomain string
queryDomain string
isWildcard bool
matchSubdomains bool
shouldMatch bool
}{
{
name: "exact match",
handlerDomain: "example.com.",
queryDomain: "example.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard and MatchSubdomains true",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard and MatchSubdomains false",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard match",
handlerDomain: "*.example.com.",
queryDomain: "sub.example.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on apex",
handlerDomain: "*.example.com.",
queryDomain: "example.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "root zone match",
handlerDomain: ".",
queryDomain: "anything.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "no match different domain",
handlerDomain: "example.com.",
queryDomain: "example.org.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handler dns.Handler
if tt.matchSubdomains {
mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
handler = mockSubHandler
if tt.shouldMatch {
mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
}
} else {
mockHandler := &nbdns.MockHandler{}
handler = mockHandler
if tt.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
}
}
pattern := tt.handlerDomain
if tt.isWildcard {
pattern = "*." + tt.handlerDomain[2:]
}
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
if h, ok := handler.(*nbdns.MockHandler); ok {
h.AssertExpectations(t)
} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
h.AssertExpectations(t)
}
})
}
}
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct {
name string
handlers []struct {
pattern string
priority int
}
queryDomain string
expectedCalls int
expectedHandler int // index of the handler that should be called
}{
{
name: "wildcard and exact same priority - exact should win",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDefault},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // exact match handler should be called
},
{
name: "higher priority wildcard over lower priority exact",
handlers: []struct {
pattern string
priority int
}{
{pattern: "example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority wildcard handler should be called
},
{
name: "multiple wildcards different priorities",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority handler should be called
},
{
name: "subdomain with mix of patterns",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "root zone with specific domain",
handlers: []struct {
pattern string
priority int
}{
{pattern: ".", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority specific domain should win over root
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handlers []*nbdns.MockHandler
// Setup handlers and expectations
for i := range tt.handlers {
handler := &nbdns.MockHandler{}
handlers = append(handlers, handler)
// Set expectation based on whether this handler should be called
if i == tt.expectedHandler {
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
} else {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
}
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
}
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
})
}
}
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create handlers
handler1 := &nbdns.MockHandler{}
handler2 := &nbdns.MockHandler{}
handler3 := &nbdns.MockHandler{}
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Setup mock responses to simulate chain continuation
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// First handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true // Signal to continue
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Second handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Last handler responds normally
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
// Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
handler1.AssertExpectations(t)
handler2.AssertExpectations(t)
handler3.AssertExpectations(t)
}
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct {
name string
ops []struct {
action string // "add" or "remove"
pattern string
priority int
}
query string
expectedCalls map[int]bool // map[priority]shouldBeCalled
}{
{
name: "remove high priority keeps lower priority handler",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true,
},
},
{
name: "remove lower priority keeps high priority handler",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false,
},
},
{
name: "remove all handlers in order",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[int]*nbdns.MockHandler)
// Execute operations
for _, op := range tt.ops {
if op.action == "add" {
handler := &nbdns.MockHandler{}
handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
}
// Create test request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations
for priority, handler := range handlers {
if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
handler.On("ServeDNS", mock.Anything, r).Once()
} else {
handler.On("ServeDNS", mock.Anything, r).Maybe()
}
}
// Execute request
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
})
}
}
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain()
testDomain := "example.com."
testQuery := "test.example.com."
// Create handlers with MatchSubdomains enabled
routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
// Create test request that will be reused
r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA)
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
// Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
routeHandler.AssertExpectations(t)
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
matchHandler.AssertExpectations(t)
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
defaultHandler.AssertExpectations(t)
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain))
}

View File

@@ -5,14 +5,14 @@ import (
"net/netip"
"strings"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns"
)
type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
applyDNSConfig(config HostDNSConfig) error
restoreHostDNS() error
supportCustomPort() bool
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
}
type SystemDNSSettings struct {
@@ -35,15 +35,15 @@ type DomainConfig struct {
}
type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
applyDNSConfigFunc func(config HostDNSConfig) error
restoreHostDNSFunc func() error
supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error
}
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config, stateManager)
return m.applyDNSConfigFunc(config)
}
return fmt.Errorf("method applyDNSSettings is not implemented")
}
@@ -62,9 +62,16 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false
}
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
if m.restoreUncleanShutdownDNSFunc != nil {
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
}
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
}
func newNoopHostMocker() hostManager {
return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true },
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },

View File

@@ -1,17 +1,15 @@
package dns
import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
import "net/netip"
type androidHostManager struct {
}
func newHostManager() (*androidHostManager, error) {
func newHostManager() (hostManager, error) {
return &androidHostManager{}, nil
}
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
return nil
}
@@ -22,3 +20,7 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool {
return false
}
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@@ -8,13 +8,12 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os/exec"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
@@ -38,7 +37,7 @@ type systemConfigurator struct {
systemDNSSettings SystemDNSSettings
}
func newHostManager() (*systemConfigurator, error) {
func newHostManager() (hostManager, error) {
return &systemConfigurator{
createdKeys: make(map[string]struct{}),
}, nil
@@ -48,11 +47,12 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true
}
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
// create a file for unclean shutdown detection
if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
}
var (
@@ -123,6 +123,10 @@ func (s *systemConfigurator) restoreHostDNS() error {
}
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil
}
@@ -316,7 +320,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)
}

Some files were not shown because too many files have changed in this diff Show More