Compare commits

..

1 Commits

Author SHA1 Message Date
braginini
920fe73096 Add perf test for the combined version
Entire-Checkpoint: a86483ba363a
2026-02-17 12:11:28 +01:00
557 changed files with 8457 additions and 70416 deletions

View File

@@ -1,14 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -409,19 +409,12 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
@@ -504,18 +497,15 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
@@ -596,18 +586,15 @@ jobs:
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: docker login for root user
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
env:
DOCKER_USER: ${{ secrets.DOCKER_USER }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |

View File

@@ -63,15 +63,10 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- name: Generate test script
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -1,51 +0,0 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -10,7 +10,7 @@ on:
env:
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.14.3"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -169,14 +169,6 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
@@ -194,54 +186,18 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
- name: Tag and push PR images (amd64 only)
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*amd64*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
else
echo "main sha-$(git rev-parse --short HEAD)"
fi
}
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
for tag in $(resolve_tags); do
dst="${img_name}:${tag}"
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
done
}
export -f tag_and_push resolve_tags
PR_TAG="pr-${{ github.event.pull_request.number }}"
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
tag_and_push "$SRC"
IMG_NAME="${SRC%%:*}"
DST="${IMG_NAME}:${PR_TAG}"
echo "Tagging ${SRC} -> ${DST}"
docker tag "$SRC" "$DST"
docker push "$DST"
done
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
@@ -309,14 +265,6 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
@@ -341,24 +289,6 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:

View File

@@ -61,8 +61,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
exit 1
fi

2
.gitignore vendored
View File

@@ -33,5 +33,3 @@ infrastructure_files/setup-*.env
vendor/
/netbird
client/netbird-electron/
management/server/types/testdata/comparison/
management/server/types/testdata/*.json

View File

@@ -154,26 +154,6 @@ builds:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -186,22 +166,18 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_deb
id: netbird-deb
bindir: /usr/bin
builds:
- netbird
formats:
- deb
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
@@ -209,19 +185,16 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_rpm
id: netbird-rpm
bindir: /usr/bin
builds:
- netbird
formats:
- rpm
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
@@ -903,7 +876,7 @@ brews:
uploads:
- name: debian
ids:
- netbird_deb
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -911,7 +884,7 @@ uploads:
- name: yum
ids:
- netbird_rpm
- netbird-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -61,7 +61,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird_ui_deb
id: netbird-ui-deb
package_name: netbird-ui
builds:
- netbird-ui
@@ -80,7 +80,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird_ui_rpm
id: netbird-ui-rpm
package_name: netbird-ui
builds:
- netbird-ui
@@ -95,14 +95,11 @@ nfpms:
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
uploads:
- name: debian
ids:
- netbird_ui_deb
- netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -110,7 +107,7 @@ uploads:
- name: yum
ids:
- netbird_ui_rpm
- netbird-ui-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -1,7 +1,7 @@
## Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance

View File

@@ -126,7 +126,6 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.3
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -17,7 +17,8 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,7 +23,8 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -205,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
pi := PeerInfo{
p.IP,
p.FQDN,
int(p.ConnStatus),
p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi

View File

@@ -2,20 +2,11 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
// Connection status constants exported via gomobile.
const (
ConnStatusIdle = int(peer.StatusIdle)
ConnStatusConnecting = int(peer.StatusConnecting)
ConnStatusConnected = int(peer.StatusConnected)
)
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
FQDN string
ConnStatus int
ConnStatus string // Todo replace to enum
Routes PeerRoutes
}

View File

@@ -181,11 +181,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
@@ -199,13 +198,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("netbird down")
time.Sleep(1 * time.Second)
@@ -213,15 +209,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = false
cmd.Println("netbird up")
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("netbird up")
time.Sleep(3 * time.Second)
@@ -267,28 +261,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
} else {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())

View File

@@ -1,287 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"regexp"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
exposeExternalPort uint16
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: ` netbird expose --with-password safe-pass 8080
netbird expose --protocol tcp 5432
netbird expose --protocol tcp --with-external-port 5433 5432
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
RunE: exposeFn,
}
func init() {
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
}
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
func isClusterProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp", "tls":
return true
default:
return false
}
}
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
// where domain display doesn't apply. TLS uses SNI so it has a domain.
func isPortBasedProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp":
return true
default:
return false
}
}
// extractPort returns the port portion of a URL like "tcp://host:12345", or
// falls back to the given default formatted as a string.
func extractPort(serviceURL string, fallback uint16) string {
u := serviceURL
if idx := strings.Index(u, "://"); idx != -1 {
u = u[idx+3:]
}
if i := strings.LastIndex(u, ":"); i != -1 {
if p := u[i+1:]; p != "" {
return p
}
}
return strconv.FormatUint(uint64(fallback), 10)
}
// resolveExternalPort returns the effective external port, defaulting to the target port.
func resolveExternalPort(targetPort uint64) uint16 {
if exposeExternalPort != 0 {
return exposeExternalPort
}
return uint16(targetPort)
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
port, err := strconv.ParseUint(portStr, 10, 32)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port == 0 || port > 65535 {
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
}
if isClusterProtocol(exposeProtocol) {
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
}
} else if cmd.Flags().Changed("with-external-port") {
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
}
if cmd.Flags().Changed("with-password") && exposePassword == "" {
return 0, fmt.Errorf("password cannot be empty")
}
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
return 0, fmt.Errorf("user groups cannot be empty")
}
return port, nil
}
func isProtocolValid(exposeProtocol string) bool {
switch strings.ToLower(exposeProtocol) {
case "http", "https", "tcp", "udp", "tls":
return true
default:
return false
}
}
func exposeFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
log.Errorf("failed initializing log %v", err)
return err
}
cmd.Root().SilenceUsage = false
port, err := validateExposeFlags(cmd, args[0])
if err != nil {
return err
}
cmd.Root().SilenceUsage = true
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close daemon connection: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
protocol, err := toExposeProtocol(exposeProtocol)
if err != nil {
return err
}
req := &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
Password: exposePassword,
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
}
if isClusterProtocol(exposeProtocol) {
req.ListenPort = uint32(resolveExternalPort(port))
}
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
return err
}
return waitForExposeEvents(cmd, ctx, stream)
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
p, err := expose.ParseProtocolType(exposeProtocol)
if err != nil {
return 0, fmt.Errorf("invalid protocol: %w", err)
}
switch p {
case expose.ProtocolHTTP:
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case expose.ProtocolHTTPS:
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case expose.ProtocolTCP:
return proto.ExposeProtocol_EXPOSE_TCP, nil
case expose.ProtocolUDP:
return proto.ExposeProtocol_EXPOSE_UDP, nil
case expose.ProtocolTLS:
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unhandled protocol type: %d", p)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
if !ok {
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
printExposeReady(cmd, ready.Ready, port)
return nil
}
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", r.ServiceName)
if r.ServiceUrl != "" {
cmd.Printf(" URL: %s\n", r.ServiceUrl)
}
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
cmd.Printf(" Domain: %s\n", r.Domain)
}
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Internal: %d\n", port)
if isClusterProtocol(exposeProtocol) {
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
}
if r.PortAutoAssigned && exposeExternalPort != 0 {
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
}
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
for {
_, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.Println("\nService stopped.")
//nolint:nilerr
return nil
}
if errors.Is(err, io.EOF) {
return fmt.Errorf("connection to daemon closed unexpectedly")
}
return fmt.Errorf("stream error: %w", err)
}
}
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -81,15 +80,6 @@ var (
Short: "",
Long: "",
SilenceUsage: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(cmd.Root())
// Don't resolve for service commands — they create the socket, not connect to it.
if !isServiceCmd(cmd) {
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
}
return nil
},
}
)
@@ -154,7 +144,6 @@ func init() {
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
rootCmd.AddCommand(profileCmd)
rootCmd.AddCommand(exposeCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
@@ -396,6 +385,7 @@ func migrateToNetbird(oldPath, newPath string) bool {
}
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
@@ -408,13 +398,3 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
return conn, nil
}
// isServiceCmd returns true if cmd is the "service" command or a child of it.
func isServiceCmd(cmd *cobra.Command) bool {
for c := cmd; c != nil; c = c.Parent() {
if c.Name() == "service" {
return true
}
}
return false
}

View File

@@ -41,7 +41,7 @@ func init() {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")

View File

@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())

View File

@@ -119,10 +119,6 @@ var installCmd = &cobra.Command{
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
@@ -140,10 +136,6 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
cmd.Println("NetBird service has been installed")
return nil
},
@@ -195,10 +187,6 @@ This command will temporarily stop the service, update its configuration, and re
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
@@ -234,10 +222,6 @@ This command will temporarily stop the service, update its configuration, and re
return fmt.Errorf("install service with new config: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
if wasRunning {
cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil {

View File

@@ -1,201 +0,0 @@
//go:build !ios && !android
package cmd
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/util"
)
const serviceParamsFile = "service.json"
// serviceParams holds install-time service parameters that persist across
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
type serviceParams struct {
LogLevel string `json:"log_level"`
DaemonAddr string `json:"daemon_addr"`
ManagementURL string `json:"management_url,omitempty"`
ConfigPath string `json:"config_path,omitempty"`
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
// serviceParamsPath returns the path to the service params file.
func serviceParamsPath() string {
return filepath.Join(configs.StateDir, serviceParamsFile)
}
// loadServiceParams reads saved service parameters from disk.
// Returns nil with no error if the file does not exist.
func loadServiceParams() (*serviceParams, error) {
path := serviceParamsPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil //nolint:nilnil
}
return nil, fmt.Errorf("read service params %s: %w", path, err)
}
var params serviceParams
if err := json.Unmarshal(data, &params); err != nil {
return nil, fmt.Errorf("parse service params %s: %w", path, err)
}
return &params, nil
}
// saveServiceParams writes current service parameters to disk atomically
// with restricted permissions.
func saveServiceParams(params *serviceParams) error {
path := serviceParamsPath()
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
return fmt.Errorf("save service params: %w", err)
}
return nil
}
// currentServiceParams captures the current state of all package-level
// variables into a serviceParams struct.
func currentServiceParams() *serviceParams {
params := &serviceParams{
LogLevel: logLevel,
DaemonAddr: daemonAddr,
ManagementURL: managementURL,
ConfigPath: configPath,
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
return params
}
// loadAndApplyServiceParams loads saved params from disk and applies them
// to any flags that were not explicitly set.
func loadAndApplyServiceParams(cmd *cobra.Command) error {
params, err := loadServiceParams()
if err != nil {
return err
}
applyServiceParams(cmd, params)
return nil
}
// applyServiceParams merges saved parameters into package-level variables
// for any flag that was not explicitly set by the user (via CLI or env var).
// Flags that were Changed() are left untouched.
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
if params == nil {
return
}
// For fields with non-empty defaults (log-level, daemon-addr), keep the
// != "" guard so that an older service.json missing the field doesn't
// clobber the default with an empty string.
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
logLevel = params.LogLevel
}
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
daemonAddr = params.DaemonAddr
}
// For optional fields where empty means "use default", always apply so
// that an explicit clear (--management-url "") persists across reinstalls.
if !rootCmd.PersistentFlags().Changed("management-url") {
managementURL = params.ManagementURL
}
if !rootCmd.PersistentFlags().Changed("config") {
configPath = params.ConfigPath
}
if !rootCmd.PersistentFlags().Changed("log-file") {
logFiles = params.LogFiles
}
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
profilesDisabled = params.DisableProfiles
}
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
updateSettingsDisabled = params.DisableUpdateSettings
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict
serviceEnvVars = envMapToSlice(merged)
}
var resetParamsCmd = &cobra.Command{
Use: "reset-params",
Short: "Remove saved service install parameters",
Long: "Removes the saved service.json file so the next install uses default parameters.",
RunE: func(cmd *cobra.Command, args []string) error {
path := serviceParamsPath()
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
cmd.Println("No saved service parameters found")
return nil
}
return fmt.Errorf("remove service params: %w", err)
}
cmd.Printf("Removed saved service parameters (%s)\n", path)
return nil
},
}
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
func envMapToSlice(m map[string]string) []string {
s := make([]string, 0, len(m))
for k, v := range m {
s = append(s, k+"="+v)
}
return s
}

View File

@@ -1,523 +0,0 @@
//go:build !ios && !android
package cmd
import (
"encoding/json"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/configs"
)
func TestServiceParamsPath(t *testing.T) {
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params := &serviceParams{
LogLevel: "debug",
DaemonAddr: "unix:///var/run/netbird.sock",
ManagementURL: "https://my.server.com",
ConfigPath: "/etc/netbird/config.json",
LogFiles: []string{"/var/log/netbird/client.log", "console"},
DisableProfiles: true,
DisableUpdateSettings: false,
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
}
err := saveServiceParams(params)
require.NoError(t, err)
// Verify the file exists and is valid JSON.
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
require.NoError(t, err)
assert.True(t, json.Valid(data))
loaded, err := loadServiceParams()
require.NoError(t, err)
require.NotNil(t, loaded)
assert.Equal(t, params.LogLevel, loaded.LogLevel)
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
assert.Equal(t, params.LogFiles, loaded.LogFiles)
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
}
func TestLoadServiceParams_FileNotExists(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params, err := loadServiceParams()
assert.NoError(t, err)
assert.Nil(t, params)
}
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
require.NoError(t, err)
params, err := loadServiceParams()
assert.Error(t, err)
assert.Nil(t, params)
}
func TestCurrentServiceParams(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
logLevel = "trace"
daemonAddr = "tcp://127.0.0.1:9999"
managementURL = "https://mgmt.example.com"
configPath = "/tmp/test-config.json"
logFiles = []string{"/tmp/test.log"}
profilesDisabled = true
updateSettingsDisabled = true
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
params := currentServiceParams()
assert.Equal(t, "trace", params.LogLevel)
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
assert.True(t, params.DisableProfiles)
assert.True(t, params.DisableUpdateSettings)
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
}
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
// Reset all flags to defaults.
logLevel = "info"
daemonAddr = "unix:///var/run/netbird.sock"
managementURL = ""
configPath = "/etc/netbird/config.json"
logFiles = []string{"/var/log/netbird/client.log"}
profilesDisabled = false
updateSettingsDisabled = false
serviceEnvVars = nil
// Reset Changed state on all relevant flags.
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Simulate user explicitly setting --log-level via CLI.
logLevel = "warn"
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
saved := &serviceParams{
LogLevel: "debug",
DaemonAddr: "tcp://127.0.0.1:5555",
ManagementURL: "https://saved.example.com",
ConfigPath: "/saved/config.json",
LogFiles: []string{"/saved/client.log"},
DisableProfiles: true,
DisableUpdateSettings: true,
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
// log-level was Changed, so it should keep "warn", not use saved "debug".
assert.Equal(t, "warn", logLevel)
// All other fields were not Changed, so they should use saved values.
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
assert.Equal(t, "https://saved.example.com", managementURL)
assert.Equal(t, "/saved/config.json", configPath)
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
assert.True(t, profilesDisabled)
assert.True(t, updateSettingsDisabled)
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
}
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
t.Cleanup(func() {
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
})
// Simulate current state where booleans are true (e.g. set by previous install).
profilesDisabled = true
updateSettingsDisabled = true
// Reset Changed state so flags appear unset.
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Saved params have both as false.
saved := &serviceParams{
DisableProfiles: false,
DisableUpdateSettings: false,
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.False(t, profilesDisabled, "saved false should override current true")
assert.False(t, updateSettingsDisabled, "saved false should override current true")
}
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
origManagementURL := managementURL
t.Cleanup(func() { managementURL = origManagementURL })
managementURL = "https://leftover.example.com"
// Simulate saved params where management URL was explicitly cleared.
saved := &serviceParams{
LogLevel: "info",
DaemonAddr: "unix:///var/run/netbird.sock",
// ManagementURL intentionally empty: was cleared with --management-url "".
}
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
}
func TestApplyServiceParams_NilParams(t *testing.T) {
origLogLevel := logLevel
t.Cleanup(func() { logLevel = origLogLevel })
logLevel = "info"
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
// Should be a no-op.
applyServiceParams(cmd, nil)
assert.Equal(t, "info", logLevel)
}
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Set up a command with --service-env marked as Changed.
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
saved := &serviceParams{
ServiceEnvVars: map[string]string{
"SAVED": "val",
"OVERLAP": "saved",
},
}
applyServiceEnvParams(cmd, saved)
// Parse result for easier assertion.
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, "yes", result["EXPLICIT"])
assert.Equal(t, "val", result["SAVED"])
// Explicit wins on conflict.
assert.Equal(t, "explicit", result["OVERLAP"])
}
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
serviceEnvVars = nil
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
saved := &serviceParams{
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
}
applyServiceEnvParams(cmd, saved)
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
// Collect all JSON field names from the serviceParams struct.
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
// Collect field names referenced in currentServiceParams and applyServiceParams.
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
// applyServiceEnvParams handles ServiceEnvVars indirectly.
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
for k, v := range applyEnvFields {
applyFields[k] = v
}
for _, field := range structFields {
assert.Contains(t, currentFields, field,
"serviceParams field %q is not captured in currentServiceParams()", field)
assert.Contains(t, applyFields, field,
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
}
}
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
// it flows through newSVCConfig() EnvVars, not CLI args.
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields)
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
require.NoError(t, err)
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
fieldsNotInArgs := map[string]bool{
"ServiceEnvVars": true,
}
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
// Forward: every struct field must appear in buildServiceArguments.
for _, field := range structFields {
if fieldsNotInArgs[field] {
continue
}
globalVar := fieldToGlobalVar(field)
assert.Contains(t, buildFields, globalVar,
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
}
// Reverse: every service-related global used in buildServiceArguments must
// have a corresponding serviceParams field. This catches a developer adding
// a new flag to buildServiceArguments without adding it to the struct.
globalToField := make(map[string]string, len(structFields))
for _, field := range structFields {
globalToField[fieldToGlobalVar(field)] = field
}
// Identifiers in buildServiceArguments that are not service params
// (builtins, boilerplate, loop variables).
nonParamGlobals := map[string]bool{
"args": true, "append": true, "string": true, "_": true,
"logFile": true, // range variable over logFiles
}
for ref := range buildFields {
if nonParamGlobals[ref] {
continue
}
_, inStruct := globalToField[ref]
assert.True(t, inStruct,
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
}
}
// extractStructJSONFields returns field names from a named struct type.
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
t.Helper()
var fields []string
ast.Inspect(file, func(n ast.Node) bool {
ts, ok := n.(*ast.TypeSpec)
if !ok || ts.Name.Name != structName {
return true
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
return false
}
for _, f := range st.Fields.List {
if len(f.Names) > 0 {
fields = append(fields, f.Names[0].Name)
}
}
return false
})
return fields
}
// extractFuncFieldRefs returns which of the given field names appear inside the
// named function, either as selector expressions (params.FieldName) or as
// composite literal keys (&serviceParams{FieldName: ...}).
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
t.Helper()
fieldSet := make(map[string]bool, len(fields))
for _, f := range fields {
fieldSet[f] = true
}
found := make(map[string]bool)
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
ast.Inspect(fn.Body, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.SelectorExpr:
if fieldSet[v.Sel.Name] {
found[v.Sel.Name] = true
}
case *ast.KeyValueExpr:
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
found[ident.Name] = true
}
}
return true
})
return found
}
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
t.Helper()
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
refs := make(map[string]bool)
ast.Inspect(fn.Body, func(n ast.Node) bool {
if ident, ok := n.(*ast.Ident); ok {
refs[ident.Name] = true
}
return true
})
return refs
}
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
for _, decl := range file.Decls {
fn, ok := decl.(*ast.FuncDecl)
if ok && fn.Name.Name == name {
return fn
}
}
return nil
}
// fieldToGlobalVar maps serviceParams field names to the package-level variable
// names used in buildServiceArguments and applyServiceParams.
func fieldToGlobalVar(field string) string {
m := map[string]string{
"LogLevel": "logLevel",
"DaemonAddr": "daemonAddr",
"ManagementURL": "managementURL",
"ConfigPath": "configPath",
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {
return v
}
// Default: lowercase first letter.
return strings.ToLower(field[:1]) + field[1:]
}
func TestEnvMapToSlice(t *testing.T) {
m := map[string]string{"A": "1", "B": "2"}
s := envMapToSlice(m)
assert.Len(t, s, 2)
assert.Contains(t, s, "A=1")
assert.Contains(t, s, "B=2")
}
func TestEnvMapToSlice_Empty(t *testing.T) {
s := envMapToSlice(map[string]string{})
assert.Empty(t, s)
}

View File

@@ -4,9 +4,7 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -15,22 +13,6 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -97,34 +79,6 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (

View File

@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (

View File

@@ -28,7 +28,6 @@ var (
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
checkFlag string
)
var statusCmd = &cobra.Command{
@@ -50,7 +49,6 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -58,10 +56,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
if checkFlag != "" {
return runHealthCheck(cmd)
}
err := parseFilters()
if err != nil {
return err
@@ -74,17 +68,15 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, true, false)
resp, err := getStatus(ctx, false)
if err != nil {
return err
}
status := resp.GetStatus()
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired)
if needsAuth && !jsonFlag && !yamlFlag {
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -107,17 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
DaemonVersion: resp.GetDaemonVersion(),
DaemonStatus: nbstatus.ParseDaemonStatus(status),
StatusFilter: statusFilter,
PrefixNamesFilter: prefixNamesFilter,
PrefixNamesFilterMap: prefixNamesFilterMap,
IPsFilter: ipsFilterMap,
ConnectionTypeFilter: connectionTypeFilter,
ProfileName: profName,
})
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:
@@ -139,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
@@ -149,7 +131,7 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -203,83 +185,6 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func runHealthCheck(cmd *cobra.Command) error {
check := strings.ToLower(checkFlag)
switch check {
case "live", "ready", "startup":
default:
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
}
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
isStartup := check == "startup"
resp, err := getStatus(ctx, isStartup, false)
if err != nil {
return err
}
switch check {
case "live":
return nil
case "ready":
return checkReadiness(resp)
case "startup":
return checkStartup(resp)
default:
return nil
}
}
func checkReadiness(resp *proto.StatusResponse) error {
daemonStatus := internal.StatusType(resp.GetStatus())
switch daemonStatus {
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
return nil
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
default:
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
}
}
func checkStartup(resp *proto.StatusResponse) error {
fullStatus := resp.GetFullStatus()
if fullStatus == nil {
return fmt.Errorf("startup check: no full status available")
}
if !fullStatus.GetManagementState().GetConnected() {
return fmt.Errorf("startup check: management not connected")
}
if !fullStatus.GetSignalState().GetConnected() {
return fmt.Errorf("startup check: signal not connected")
}
var relayCount, relaysConnected int
for _, r := range fullStatus.GetRelays() {
uri := r.GetURI()
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
continue
}
relayCount++
if r.GetAvailable() {
relaysConnected++
}
}
if relayCount > 0 && relaysConnected == 0 {
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
}
return nil
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))

View File

@@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)

View File

@@ -14,7 +14,6 @@ import (
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
@@ -22,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -33,14 +31,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -83,14 +81,6 @@ type Options struct {
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the WireGuard interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
}
// validateCredentials checks that exactly one credential type is provided
@@ -122,12 +112,6 @@ func New(opts Options) (*Client, error) {
return nil, err
}
if opts.MTU != nil {
if err := iface.ValidateMTU(*opts.MTU); err != nil {
return nil, fmt.Errorf("invalid MTU: %w", err)
}
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -156,14 +140,9 @@ func New(opts Options) (*Client, error) {
}
}
var err error
var parsedLabels domain.List
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
return nil, fmt.Errorf("invalid dns labels: %w", err)
}
t := true
var config *profilemanager.Config
var err error
input := profilemanager.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
@@ -172,8 +151,6 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -225,7 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder)
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -375,32 +352,6 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
// It returns an ExposeSession. Call Wait on the session to keep it alive.
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
mgr := engine.GetExposeManager()
if mgr == nil {
return nil, fmt.Errorf("expose manager not available")
}
resp, err := mgr.Expose(ctx, req)
if err != nil {
return nil, fmt.Errorf("expose: %w", err)
}
return &ExposeSession{
Domain: resp.Domain,
ServiceName: resp.ServiceName,
ServiceURL: resp.ServiceURL,
mgr: mgr,
}, nil
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()

View File

@@ -1,45 +0,0 @@
package embed
import (
"context"
"errors"
"github.com/netbirdio/netbird/client/internal/expose"
)
const (
// ExposeProtocolHTTP exposes the service as HTTP.
ExposeProtocolHTTP = expose.ProtocolHTTP
// ExposeProtocolHTTPS exposes the service as HTTPS.
ExposeProtocolHTTPS = expose.ProtocolHTTPS
// ExposeProtocolTCP exposes the service as TCP.
ExposeProtocolTCP = expose.ProtocolTCP
// ExposeProtocolUDP exposes the service as UDP.
ExposeProtocolUDP = expose.ProtocolUDP
// ExposeProtocolTLS exposes the service as TLS.
ExposeProtocolTLS = expose.ProtocolTLS
)
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
type ExposeRequest = expose.Request
// ExposeProtocolType represents the protocol used for exposing a service.
type ExposeProtocolType = expose.ProtocolType
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
type ExposeSession struct {
Domain string
ServiceName string
ServiceURL string
mgr *expose.Manager
}
// Wait blocks while keeping the expose session alive.
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
func (s *ExposeSession) Wait(ctx context.Context) error {
if s == nil || s.mgr == nil {
return errors.New("expose session is not initialized")
}
return s.mgr.KeepAlive(ctx, s.Domain)
}

View File

@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"os"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
@@ -36,27 +35,20 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
return fm, nil
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
@@ -168,17 +160,3 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func forceUserspaceFirewall() bool {
val := os.Getenv(EnvForceUserspaceFirewall)
if val == "" {
return false
}
force, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
return false
}
return force
}

View File

@@ -7,12 +7,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
// native iptables/nftables is available. This only applies when the WireGuard interface
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
// kernel netfilter rules.
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string

View File

@@ -23,16 +23,16 @@ type Manager struct {
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
rawSupported bool
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Create iptables firewall manager
@@ -63,9 +63,10 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -83,7 +84,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.initNoTrackChain(); err != nil {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
@@ -201,10 +202,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
@@ -282,22 +285,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
@@ -331,10 +318,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !m.rawSupported {
return fmt.Errorf("raw table not available")
}
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
@@ -364,28 +347,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for iptables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for iptables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)
@@ -414,16 +375,12 @@ func (m *Manager) initNoTrackChain() error {
return fmt.Errorf("add prerouting jump rule: %w", err)
}
m.rawSupported = true
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
if !m.rawSupported {
return nil
}
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
@@ -444,7 +401,6 @@ func (m *Manager) cleanupNoTrackChain() error {
return fmt.Errorf("clear and delete chain: %w", err)
}
m.rawSupported = false
return nil
}

View File

@@ -47,6 +47,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)

View File

@@ -36,7 +36,6 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
@@ -44,7 +43,6 @@ const (
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
@@ -89,8 +87,6 @@ type router struct {
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
tproxyRules []tproxyRuleEntry
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -391,14 +387,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
chain string
table string
@@ -408,7 +396,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -983,81 +970,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
@@ -1111,92 +1023,3 @@ func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}
// AddTProxyRule adds iptables nat PREROUTING REDIRECT rules for transparent proxy interception.
// Traffic from sources on dstPorts arriving on the WG interface is redirected
// to the transparent proxy listener on redirectPort.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
portStr := fmt.Sprintf("%d", redirectPort)
for _, proto := range []string{"tcp", "udp"} {
srcSpecs := r.buildSourceSpecs(sources)
for _, srcSpec := range srcSpecs {
if len(dstPorts) == 0 {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s: %w", ruleID, proto, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
} else {
for _, port := range dstPorts {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"--dport", fmt.Sprintf("%d", port),
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s/%d: %w", ruleID, proto, port, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
}
}
}
}
return nil
}
// RemoveTProxyRule removes all iptables REDIRECT rules for the given ruleID.
func (r *router) RemoveTProxyRule(ruleID string) error {
var remaining []tproxyRuleEntry
for _, entry := range r.tproxyRules {
if entry.ruleID != ruleID {
remaining = append(remaining, entry)
continue
}
if err := r.iptablesClient.DeleteIfExists(entry.table, entry.chain, entry.spec...); err != nil {
log.Debugf("remove tproxy rule %s: %v", ruleID, err)
}
}
r.tproxyRules = remaining
return nil
}
type tproxyRuleEntry struct {
ruleID string
table string
chain string
spec []string
}
func (r *router) buildSourceSpecs(sources []netip.Prefix) [][]string {
if len(sources) == 0 {
return [][]string{{}} // empty spec = match any source
}
specs := make([][]string, 0, len(sources))
for _, src := range sources {
specs = append(specs, []string{"-s", src.String()})
}
return specs
}

View File

@@ -9,9 +9,10 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -22,6 +23,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex

View File

@@ -169,33 +169,9 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
// AddTProxyRule adds TPROXY redirect rules for specific source CIDRs and destination ports.
// Traffic from sources on dstPorts is redirected to the transparent proxy on redirectPort.
// Empty dstPorts means redirect all ports.
AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error
// RemoveTProxyRule removes TPROXY redirect rules by ID.
RemoveTProxyRule(ruleID string) error
// AddUDPInspectionHook registers a hook that inspects UDP packets before forwarding.
// The hook receives the raw packet and returns true to drop it.
// Used for QUIC SNI-based blocking. Returns a hook ID for removal.
AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string
// RemoveUDPInspectionHook removes a previously registered inspection hook.
RemoveUDPInspectionHook(hookID string)
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -40,6 +40,7 @@ func getTableName() string {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Manager of iptables firewall
@@ -94,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
@@ -105,9 +106,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -203,10 +205,12 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -342,22 +346,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"
@@ -482,28 +470,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for nftables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for nftables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,

View File

@@ -52,6 +52,8 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface

View File

@@ -36,7 +36,6 @@ const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
@@ -77,7 +76,6 @@ type router struct {
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -1855,130 +1853,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
@@ -2138,227 +2012,3 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
},
}, nil
}
// AddTProxyRule adds nftables TPROXY redirect rules in the mangle prerouting chain.
// Traffic from sources on dstPorts arriving on the WG interface is redirected to
// the transparent proxy listener on redirectPort.
// Separate rules are created for TCP and UDP protocols.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
// Use the nat redirect chain for DNAT rules.
// TPROXY doesn't work on WG kernel interfaces (socket assignment silently fails),
// so we use DNAT to 127.0.0.1:proxy_port instead. The proxy reads the original
// destination via SO_ORIGINAL_DST (conntrack).
chain := r.chains[chainNameRoutingRdr]
if chain == nil {
return fmt.Errorf("nat redirect chain not initialized")
}
for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP} {
protoName := "tcp"
if proto == unix.IPPROTO_UDP {
protoName = "udp"
}
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, protoName)
if existing, ok := r.rules[ruleKey]; ok && existing.Handle != 0 {
if err := r.decrementSetCounter(existing); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(existing); err != nil {
log.Debugf("remove existing tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
}
exprs, err := r.buildRedirectExprs(proto, sources, dstPorts, redirectPort)
if err != nil {
return fmt.Errorf("build redirect exprs for %s: %w", protoName, err)
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
})
}
// Accept redirected packets in the ACL input chain. After REDIRECT, the
// destination port becomes the proxy port. Without this rule, the ACL filter
// drops the packet. We match on ct state dnat so only REDIRECT'd connections
// are accepted: direct connections to the proxy port are blocked.
inputAcceptKey := fmt.Sprintf("tproxy-%s-input", ruleID)
if _, ok := r.rules[inputAcceptKey]; !ok {
inputChain := &nftables.Chain{
Name: "netbird-acl-input-rules",
Table: r.workTable,
}
r.rules[inputAcceptKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: inputChain,
Exprs: []expr.Any{
// Only accept connections that were REDIRECT'd (ct status dnat)
&expr.Ct{Register: 1, Key: expr.CtKeySTATUS},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(0x20), // IPS_DST_NAT
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
// Accept both TCP and UDP redirected to the proxy port.
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(redirectPort),
},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(inputAcceptKey),
})
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rules for %s: %w", ruleID, err)
}
return nil
}
// RemoveTProxyRule removes TPROXY redirect rules by ID (both TCP and UDP variants).
func (r *router) RemoveTProxyRule(ruleID string) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var removed int
for _, suffix := range []string{"tcp", "udp", "input"} {
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, suffix)
rule, ok := r.rules[ruleKey]
if !ok {
continue
}
if rule.Handle == 0 {
delete(r.rules, ruleKey)
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(rule); err != nil {
log.Debugf("delete tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
removed++
}
if removed > 0 {
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rule removal for %s: %w", ruleID, err)
}
}
return nil
}
// buildRedirectExprs builds nftables expressions for a REDIRECT rule.
// Matches WG interface ingress, source CIDRs, destination ports, then REDIRECTs to the proxy port.
func (r *router) buildRedirectExprs(proto uint8, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) ([]expr.Any, error) {
var exprs []expr.Any
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname(r.wgIface.Name())},
)
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
)
// Source CIDRs use the named ipset shared with route rules.
if len(sources) > 0 {
srcSet := firewall.NewPrefixSet(sources)
srcExprs, err := r.getIpSet(srcSet, sources, true)
if err != nil {
return nil, fmt.Errorf("get source ipset: %w", err)
}
exprs = append(exprs, srcExprs...)
}
if len(dstPorts) == 1 {
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(dstPorts[0]),
},
)
} else if len(dstPorts) > 1 {
setElements := make([]nftables.SetElement, len(dstPorts))
for i, p := range dstPorts {
setElements[i] = nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)}
}
portSet := &nftables.Set{
Table: r.workTable,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetService,
}
if err := r.conn.AddSet(portSet, setElements); err != nil {
return nil, fmt.Errorf("create port set: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetName: portSet.Name,
SetID: portSet.ID,
},
)
}
// REDIRECT to local proxy port. Changes the destination to the interface's
// primary address + specified port. Conntrack tracks the original destination,
// readable via SO_ORIGINAL_DST.
exprs = append(exprs,
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(redirectPort)},
&expr.Redir{
RegisterProtoMin: 1,
},
)
return exprs, nil
}

View File

@@ -8,9 +8,10 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -21,6 +22,10 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}

View File

@@ -140,17 +140,6 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -605,8 +594,6 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -641,45 +628,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// AddTProxyRule delegates to the native firewall for TPROXY rules.
// In userspace mode (no native firewall), this is a no-op since the
// forwarder intercepts traffic directly.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// AddUDPInspectionHook registers a hook for QUIC/UDP inspection via the packet filter.
func (m *Manager) AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string {
m.SetUDPPacketHook(netip.Addr{}, dstPort, hook)
return "udp-inspection"
}
// RemoveUDPInspectionHook removes a previously registered inspection hook.
func (m *Manager) RemoveUDPInspectionHook(_ string) {
m.SetUDPPacketHook(netip.Addr{}, 0, nil)
}
// RemoveTProxyRule delegates to the native firewall for TPROXY rules.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveTProxyRule(ruleID)
}
// IsLocalIP reports whether the given IP belongs to the local machine.
func (m *Manager) IsLocalIP(ip netip.Addr) bool {
return m.localipmanager.IsLocalIP(ip)
}
// GetForwarder returns the userspace packet forwarder, or nil if not initialized.
func (m *Manager) GetForwarder() *forwarder.Forwarder {
return m.forwarder.Load()
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -765,9 +713,6 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -950,21 +895,38 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
m.mutex.RLock()
defer m.mutex.RUnlock()
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1316,6 +1278,12 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1374,30 +1342,65 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -187,52 +186,81 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -502,12 +530,39 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -537,7 +592,8 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
manager.SetUDPPacketHook(
hookID := manager.AddUDPPacketHook(
false,
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -545,6 +601,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
@@ -22,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/inspect"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -48,10 +46,6 @@ type Forwarder struct {
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
// proxy is the optional inspection engine.
// When set, TCP connections are handed to the engine for protocol detection
// and rule evaluation. Swapped atomically for lock-free hot-path access.
proxy atomic.Pointer[inspect.Proxy]
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -85,7 +79,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("add protocol address: %s", err)
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
@@ -161,13 +155,6 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
return nil
}
// SetProxy sets the inspection engine. When set, TCP connections are handed
// to it for protocol detection and rule evaluation instead of direct relay.
// Pass nil to disable inspection.
func (f *Forwarder) SetProxy(p *inspect.Proxy) {
f.proxy.Store(p)
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
@@ -180,25 +167,6 @@ func (f *Forwarder) Stop() {
f.stack.Wait()
}
// CheckUDPPacket inspects a UDP payload against proxy rules before injection.
// This is called by the filter for QUIC SNI-based blocking.
// Returns true if the packet should be allowed, false if it should be dropped.
func (f *Forwarder) CheckUDPPacket(payload []byte, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) bool {
p := f.proxy.Load()
if p == nil {
return true
}
dst := netip.AddrPortFrom(dstIP, dstPort)
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(ruleID),
}
action := p.HandleUDPPacket(payload, dst, src)
return action != inspect.ActionBlock
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)

View File

@@ -16,7 +16,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
"github.com/netbirdio/netbird/client/inspect"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -24,86 +23,6 @@ import (
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
// If the inspection engine is configured, accept the connection first and hand it off.
if p := f.proxy.Load(); p != nil {
f.handleTCPWithInspection(r, id, p)
return
}
f.handleTCPDirect(r, id)
}
// handleTCPWithInspection accepts the connection and hands it to the inspection
// engine. For allow decisions, the forwarder does its own relay (passthrough).
// For block/inspect, the engine handles everything internally.
func (f *Forwarder) handleTCPWithInspection(r *tcp.ForwarderRequest, id stack.TransportEndpointID, p *inspect.Proxy) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error1("forwarder: create TCP endpoint for inspection: %v", epErr)
r.Complete(true)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
srcIP := netip.AddrFrom4(id.RemoteAddress.As4())
dstIP := netip.AddrFrom4(id.LocalAddress.As4())
dst := netip.AddrPortFrom(dstIP, id.LocalPort)
var policyID []byte
if ruleID, ok := f.getRuleID(srcIP, dstIP, id.RemotePort, id.LocalPort); ok {
policyID = ruleID
}
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(policyID),
}
f.logger.Trace1("forwarder: handing TCP %v to inspection engine", epID(id))
go func() {
result, err := p.InspectTCP(f.ctx, inConn, dst, src)
if err != nil && err != inspect.ErrBlocked {
f.logger.Debug2("forwarder: inspection error for %v: %v", epID(id), err)
}
// Passthrough: engine returned allow, forwarder does the relay.
if result.PassthroughConn != nil {
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, dialErr := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if dialErr != nil {
f.logger.Trace2("forwarder: passthrough dial error for %v: %v", epID(id), dialErr)
if closeErr := result.PassthroughConn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: close passthrough conn: %v", closeErr)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
f.proxyTCPPassthrough(id, result.PassthroughConn, outConn, ep, flowID)
return
}
// Engine handled it (block/inspect/HTTP). Capture stats and clean up.
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, rxPackets, txPackets)
}()
}
// handleTCPDirect handles TCP connections with direct relay (no proxy).
func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportEndpointID) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
@@ -123,6 +42,7 @@ func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportE
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
@@ -135,6 +55,7 @@ func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportE
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
@@ -152,6 +73,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
@@ -210,66 +132,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
// proxyTCPPassthrough relays traffic between a peeked inbound connection
// (from the inspection engine passthrough) and the outbound connection.
// It accepts net.Conn for inConn since the inspection engine wraps it in a peekConn.
func (f *Forwarder) proxyTCPPassthrough(id stack.TransportEndpointID, inConn net.Conn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough inConn close: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough outConn close: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesIn int64
bytesOut int64
errIn error
errOut error
)
go func() {
bytesIn, errIn = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesOut, errOut = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errIn != nil && !isClosedError(errIn) {
f.logger.Error2("proxyTCPPassthrough: copy error (in→out) for %s: %v", epID(id), errIn)
}
if errOut != nil && !isClosedError(errOut) {
f.logger.Error2("proxyTCPPassthrough: copy error (out→in) for %s: %v", epID(id), errOut)
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: passthrough TCP %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesOut, txPackets, bytesIn)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesOut), uint64(bytesIn), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())

View File

@@ -144,8 +144,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
@@ -421,7 +421,6 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -467,22 +466,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,7 +18,9 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
drop bool
udpHook func([]byte) bool
}
// ID returns the rule id

View File

@@ -399,17 +399,21 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,

View File

@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
@@ -46,7 +46,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
opts := []grpc.DialOption{
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
@@ -54,10 +56,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
}
opts = append(opts, extraOpts...)
conn, err := grpc.DialContext(connCtx, addr, opts...)
)
if err != nil {
return nil, fmt.Errorf("dial context: %w", err)
}

View File

@@ -5,18 +5,20 @@ package configurer
import (
"net"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/ipc"
)
func openUAPI(deviceName string) (net.Listener, error) {
uapiSock, err := ipc.UAPIOpen(deviceName)
if err != nil {
log.Errorf("failed to open uapi socket: %v", err)
return nil, err
}
listener, err := ipc.UAPIListen(deviceName, uapiSock)
if err != nil {
_ = uapiSock.Close()
log.Errorf("failed to listen on uapi socket: %v", err)
return nil, err
}

View File

@@ -54,14 +54,6 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)

View File

@@ -15,17 +15,14 @@ type PacketFilter interface {
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one UDP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one TCP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
}
// FilteredDevice to override Read or Write of packets

View File

@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
if cErr := tunIface.Close(); cErr != nil {

View File

@@ -34,28 +34,18 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// SetUDPPacketHook mocks base method.
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
}
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
}
// SetTCPPacketHook mocks base method.
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
}
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
@@ -85,3 +75,17 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}

View File

@@ -0,0 +1,87 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,212 +0,0 @@
package inspect
import (
"crypto"
"crypto/x509"
"net"
"net/netip"
"net/url"
"strings"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// InspectResult holds the outcome of connection inspection.
type InspectResult struct {
// Action is the rule evaluation result.
Action Action
// PassthroughConn is the client connection with buffered peeked bytes.
// Non-nil only when Action is ActionAllow and the caller should relay
// (TLS passthrough or non-HTTP/TLS protocol). The caller takes ownership
// and is responsible for closing this connection.
PassthroughConn net.Conn
}
const (
// DefaultTProxyPort is the default TPROXY listener port for kernel mode.
// Override with NB_TPROXY_PORT environment variable.
DefaultTProxyPort = 22080
)
// Action determines how the proxy handles a matched connection.
type Action string
const (
// ActionAllow passes the connection through without decryption.
ActionAllow Action = "allow"
// ActionBlock denies the connection.
ActionBlock Action = "block"
// ActionInspect decrypts (MITM) and inspects the connection.
ActionInspect Action = "inspect"
)
// ProxyMode determines the proxy operating mode.
type ProxyMode string
const (
// ModeBuiltin uses the built-in proxy with rules and optional ICAP.
ModeBuiltin ProxyMode = "builtin"
// ModeEnvoy runs a local envoy sidecar for L7 processing.
// Go manages envoy lifecycle, config generation, and rule evaluation.
// USP path forwards via PROXY protocol v2; kernel path uses nftables redirect.
ModeEnvoy ProxyMode = "envoy"
// ModeExternal forwards all traffic to an external proxy.
ModeExternal ProxyMode = "external"
)
// PolicyID is the management policy identifier associated with a connection.
type PolicyID []byte
// MatchDomain reports whether target matches the pattern.
// If pattern starts with "*.", it matches any subdomain (but not the base itself).
// Otherwise it requires an exact match.
func MatchDomain(pattern, target domain.Domain) bool {
p := pattern.PunycodeString()
t := target.PunycodeString()
if strings.HasPrefix(p, "*.") {
base := p[2:]
return strings.HasSuffix(t, "."+base)
}
return p == t
}
// SourceInfo carries source identity context for rule evaluation.
// The source may be a direct WireGuard peer or a host behind
// a site-to-site gateway.
type SourceInfo struct {
// IP is the original source address from the packet.
IP netip.Addr
// PolicyID is the management policy that allowed this traffic
// through route ACLs.
PolicyID PolicyID
}
// ProtoType identifies a protocol handled by the proxy.
type ProtoType string
const (
ProtoHTTP ProtoType = "http"
ProtoHTTPS ProtoType = "https"
ProtoH2 ProtoType = "h2"
ProtoH3 ProtoType = "h3"
ProtoWebSocket ProtoType = "websocket"
ProtoOther ProtoType = "other"
)
// Rule defines a proxy inspection/filtering rule.
type Rule struct {
// ID uniquely identifies this rule.
ID id.RuleID
// Sources are the source CIDRs this rule applies to.
// Includes both direct peer IPs and routed networks behind gateways.
Sources []netip.Prefix
// Domains are the destination domain patterns to match (via SNI or Host header).
// Supports exact match ("example.com") and wildcard ("*.example.com").
Domains []domain.Domain
// Networks are the destination CIDRs to match.
Networks []netip.Prefix
// Ports are the destination ports to match. Empty means all ports.
Ports []uint16
// Protocols restricts which protocols this rule applies to.
// Empty means all protocols.
Protocols []ProtoType
// Paths are URL path patterns to match (HTTP only, requires inspect for HTTPS).
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
// Empty means all paths.
Paths []string
// Action determines what to do with matched connections.
Action Action
// Priority controls evaluation order. Lower values are evaluated first.
Priority int
}
// ICAPConfig holds ICAP service configuration.
type ICAPConfig struct {
// ReqModURL is the ICAP REQMOD service URL (e.g., icap://server:1344/reqmod).
ReqModURL *url.URL
// RespModURL is the ICAP RESPMOD service URL (e.g., icap://server:1344/respmod).
RespModURL *url.URL
// MaxConnections is the connection pool size. Zero uses a default.
MaxConnections int
}
// TLSConfig holds the MITM CA configuration for TLS inspection.
type TLSConfig struct {
// CA is the certificate authority used to sign dynamic certificates.
CA *x509.Certificate
// CAKey is the CA's private key.
CAKey crypto.PrivateKey
}
// Config holds the transparent proxy configuration.
type Config struct {
// Enabled controls whether the proxy is active.
Enabled bool
// Mode selects built-in or external proxy operation.
Mode ProxyMode
// ExternalURL is the upstream proxy URL for ModeExternal.
// Supports http:// and socks5:// schemes.
ExternalURL *url.URL
// DefaultAction applies when no rule matches a connection.
DefaultAction Action
// RedirectSources are the source CIDRs whose traffic should be intercepted.
// Admin decides: "activate for these users/subnets."
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
RedirectSources []netip.Prefix
// RedirectPorts are the destination ports to intercept. Empty means all ports.
RedirectPorts []uint16
// Rules are the proxy inspection/filtering rules, evaluated in Priority order.
Rules []Rule
// ICAP holds ICAP service configuration. Nil disables ICAP.
ICAP *ICAPConfig
// TLS holds the MITM CA. Nil means no MITM capability (ActionInspect rules ignored).
TLS *TLSConfig
// Envoy configuration (ModeEnvoy only)
Envoy *EnvoyConfig
// ListenAddr is the TPROXY listen address for kernel mode.
// Zero value disables the TPROXY listener.
ListenAddr netip.AddrPort
// WGNetwork is the WireGuard overlay network prefix.
// The proxy blocks dialing destinations inside this network.
WGNetwork netip.Prefix
// LocalIPChecker reports whether an IP belongs to the routing peer.
// Used to prevent SSRF to local services. May be nil.
LocalIPChecker LocalIPChecker
}
// EnvoyConfig holds configuration for the envoy sidecar mode.
type EnvoyConfig struct {
// BinaryPath is the path to the envoy binary.
// Empty means search $PATH for "envoy".
BinaryPath string
// AdminPort is the port for envoy's admin API (health checks, stats).
// Zero means auto-assign.
AdminPort uint16
// Snippets are user-provided config fragments merged into the generated bootstrap.
Snippets *EnvoySnippets
}
// EnvoySnippets holds user-provided YAML fragments for envoy config customization.
// Only safe snippet types are allowed: filters (HTTP and network) and clusters
// needed as dependencies for filter services. Listeners and bootstrap overrides
// are not exposed since we manage the listener and bootstrap.
type EnvoySnippets struct {
// HTTPFilters is YAML injected into the HCM filter chain before the router filter.
// Used for ext_authz, rate limiting, Lua, Wasm, RBAC, JWT auth, etc.
HTTPFilters string
// NetworkFilters is YAML injected into the TLS filter chain before tcp_proxy.
// Used for network-level RBAC, rate limiting, ext_authz on raw TCP.
NetworkFilters string
// Clusters is YAML for additional upstream clusters referenced by filters.
// Needed when filters call external services (ext_authz backend, rate limit service).
Clusters string
}

View File

@@ -1,93 +0,0 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestMatchDomain(t *testing.T) {
tests := []struct {
name string
pattern string
target string
want bool
}{
{
name: "exact match",
pattern: "example.com",
target: "example.com",
want: true,
},
{
name: "exact no match",
pattern: "example.com",
target: "other.com",
want: false,
},
{
name: "wildcard matches subdomain",
pattern: "*.example.com",
target: "foo.example.com",
want: true,
},
{
name: "wildcard matches deep subdomain",
pattern: "*.example.com",
target: "a.b.c.example.com",
want: true,
},
{
name: "wildcard does not match base",
pattern: "*.example.com",
target: "example.com",
want: false,
},
{
name: "wildcard does not match unrelated",
pattern: "*.example.com",
target: "foo.other.com",
want: false,
},
{
name: "case insensitive exact match",
pattern: "Example.COM",
target: "example.com",
want: true,
},
{
name: "case insensitive wildcard match",
pattern: "*.Example.COM",
target: "FOO.example.com",
want: true,
},
{
name: "wildcard does not match partial suffix",
pattern: "*.example.com",
target: "notexample.com",
want: false,
},
{
name: "unicode domain punycode match",
pattern: "*.münchen.de",
target: "sub.xn--mnchen-3ya.de",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pattern, err := domain.FromString(tt.pattern)
require.NoError(t, err)
target, err := domain.FromString(tt.target)
require.NoError(t, err)
got := MatchDomain(pattern, target)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,25 +0,0 @@
package inspect
import (
"net"
"syscall"
)
// newOutboundDialer creates a net.Dialer that clears the socket fwmark.
// In kernel TPROXY mode, accepted connections inherit the TPROXY fwmark.
// Without clearing it, outbound connections from the proxy would match
// the ip rule (fwmark -> local loopback) and loop back to the proxy
// instead of reaching the real destination.
func newOutboundDialer() net.Dialer {
return net.Dialer{
Control: func(_, _ string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, 0)
}); err != nil {
return err
}
return sockErr
},
}
}

View File

@@ -1,11 +0,0 @@
//go:build !linux
package inspect
import "net"
// newOutboundDialer returns a plain dialer on non-Linux platforms.
// TPROXY is Linux-only, so no fwmark clearing is needed.
func newOutboundDialer() net.Dialer {
return net.Dialer{}
}

View File

@@ -1,298 +0,0 @@
package inspect
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
envoyStartTimeout = 15 * time.Second
envoyHealthInterval = 500 * time.Millisecond
envoyStopTimeout = 10 * time.Second
envoyDrainTime = 5
)
// envoyManager manages the lifecycle of an envoy sidecar process.
type envoyManager struct {
log *log.Entry
cmd *exec.Cmd
configPath string
listenPort uint16
adminPort uint16
cancel context.CancelFunc
blockPagePath string
mu sync.Mutex
running bool
}
// startEnvoy finds the envoy binary, generates config, and spawns the process.
// It blocks until envoy reports healthy or the timeout expires.
func startEnvoy(ctx context.Context, logger *log.Entry, config Config) (*envoyManager, error) {
envCfg := config.Envoy
if envCfg == nil {
return nil, fmt.Errorf("envoy config is nil")
}
binaryPath, err := findEnvoyBinary(envCfg.BinaryPath)
if err != nil {
return nil, fmt.Errorf("find envoy binary: %w", err)
}
// Pick admin port
adminPort := envCfg.AdminPort
if adminPort == 0 {
p, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free admin port: %w", err)
}
adminPort = p
}
// Pick listener port
listenPort, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free listener port: %w", err)
}
// Use a private temp directory (0700) to prevent local attackers from
// replacing the config file between write and envoy read.
configDir, err := os.MkdirTemp("", "nb-envoy-*")
if err != nil {
return nil, fmt.Errorf("create envoy config directory: %w", err)
}
// Write the block page HTML for envoy's direct_response to reference.
blockPagePath := filepath.Join(configDir, "block.html")
blockHTML := fmt.Sprintf(blockPageHTML, "blocked domain", "this domain")
if err := os.WriteFile(blockPagePath, []byte(blockHTML), 0600); err != nil {
return nil, fmt.Errorf("write envoy block page: %w", err)
}
// Generate config with the block page path embedded.
bootstrap, err := generateBootstrap(config, listenPort, adminPort, blockPagePath)
if err != nil {
return nil, fmt.Errorf("generate envoy bootstrap: %w", err)
}
configPath := filepath.Join(configDir, "bootstrap.yaml")
if err := os.WriteFile(configPath, bootstrap, 0600); err != nil {
return nil, fmt.Errorf("write envoy config: %w", err)
}
ctx, cancel := context.WithCancel(ctx)
cmd := exec.CommandContext(ctx, binaryPath,
"-c", configPath,
"--drain-time-s", fmt.Sprintf("%d", envoyDrainTime),
)
// Pipe envoy output to our logger.
cmd.Stdout = &logWriter{entry: logger, level: log.DebugLevel}
cmd.Stderr = &logWriter{entry: logger, level: log.WarnLevel}
if err := cmd.Start(); err != nil {
cancel()
os.Remove(configPath)
return nil, fmt.Errorf("start envoy: %w", err)
}
mgr := &envoyManager{
log: logger,
cmd: cmd,
configPath: configPath,
listenPort: listenPort,
adminPort: adminPort,
blockPagePath: blockPagePath,
cancel: cancel,
running: true,
}
// Wait for envoy to become healthy.
if err := mgr.waitHealthy(ctx); err != nil {
mgr.Stop()
return nil, fmt.Errorf("wait for envoy readiness: %w", err)
}
logger.Infof("inspect: envoy started (pid=%d, listen=%d, admin=%d)", cmd.Process.Pid, listenPort, adminPort)
// Monitor process exit in background.
go mgr.monitor()
return mgr, nil
}
// ListenAddr returns the address envoy listens on for forwarded connections.
func (m *envoyManager) ListenAddr() netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), m.listenPort)
}
// AdminAddr returns the envoy admin API address.
func (m *envoyManager) AdminAddr() string {
return fmt.Sprintf("127.0.0.1:%d", m.adminPort)
}
// Reload writes a new config and sends SIGHUP to envoy.
func (m *envoyManager) Reload(config Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return fmt.Errorf("envoy is not running")
}
bootstrap, err := generateBootstrap(config, m.listenPort, m.adminPort, m.blockPagePath)
if err != nil {
return fmt.Errorf("generate envoy bootstrap: %w", err)
}
if err := os.WriteFile(m.configPath, bootstrap, 0600); err != nil {
return fmt.Errorf("write envoy config: %w", err)
}
if err := signalReload(m.cmd.Process); err != nil {
return fmt.Errorf("signal envoy reload: %w", err)
}
m.log.Debugf("inspect: envoy config reloaded")
return nil
}
// Healthy checks the envoy admin API /ready endpoint.
func (m *envoyManager) Healthy() bool {
resp, err := http.Get(fmt.Sprintf("http://%s/ready", m.AdminAddr()))
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
// Stop terminates the envoy process and cleans up.
func (m *envoyManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
m.running = false
m.cancel()
if m.cmd.Process != nil {
done := make(chan struct{})
go func() {
m.cmd.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(envoyStopTimeout):
m.log.Warnf("inspect: envoy did not exit in %s, killing", envoyStopTimeout)
m.cmd.Process.Kill()
<-done
}
}
os.RemoveAll(filepath.Dir(m.configPath))
m.log.Infof("inspect: envoy stopped")
}
// waitHealthy polls the admin API until envoy is ready or timeout.
func (m *envoyManager) waitHealthy(ctx context.Context) error {
deadline := time.After(envoyStartTimeout)
ticker := time.NewTicker(envoyHealthInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadline:
return fmt.Errorf("envoy not ready after %s", envoyStartTimeout)
case <-ticker.C:
if m.Healthy() {
return nil
}
}
}
}
// monitor watches for unexpected envoy exits.
func (m *envoyManager) monitor() {
err := m.cmd.Wait()
m.mu.Lock()
wasRunning := m.running
m.running = false
m.mu.Unlock()
if wasRunning {
m.log.Errorf("inspect: envoy exited unexpectedly: %v", err)
}
}
// findEnvoyBinary resolves the envoy binary path.
func findEnvoyBinary(configPath string) (string, error) {
if configPath != "" {
if _, err := os.Stat(configPath); err != nil {
return "", fmt.Errorf("envoy binary not found at %s: %w", configPath, err)
}
return configPath, nil
}
path, err := exec.LookPath("envoy")
if err != nil {
return "", fmt.Errorf("envoy not found in PATH: %w", err)
}
return path, nil
}
// findFreePort asks the OS for an available TCP port.
func findFreePort() (uint16, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, err
}
port := uint16(ln.Addr().(*net.TCPAddr).Port)
ln.Close()
return port, nil
}
// logWriter adapts log.Entry to io.Writer for piping process output.
type logWriter struct {
entry *log.Entry
level log.Level
}
func (w *logWriter) Write(p []byte) (int, error) {
msg := strings.TrimRight(string(p), "\n\r")
if msg == "" {
return len(p), nil
}
switch w.level {
case log.WarnLevel:
w.entry.Warn(msg)
default:
w.entry.Debug(msg)
}
return len(p), nil
}
// Ensure logWriter satisfies io.Writer.
var _ io.Writer = (*logWriter)(nil)

View File

@@ -1,382 +0,0 @@
package inspect
import (
"bytes"
"fmt"
"strings"
"text/template"
)
// envoyBootstrapTmpl generates the full envoy bootstrap with rule translation.
// TLS rules become per-SNI filter chains; HTTP rules become per-domain virtual hosts.
var envoyBootstrapTmpl = template.Must(template.New("bootstrap").Funcs(template.FuncMap{
"quote": func(s string) string { return fmt.Sprintf("%q", s) },
}).Parse(`node:
id: netbird-inspect
cluster: netbird
admin:
address:
socket_address:
address: 127.0.0.1
port_value: {{.AdminPort}}
static_resources:
listeners:
- name: inspect_listener
address:
socket_address:
address: 127.0.0.1
port_value: {{.ListenPort}}
listener_filters:
- name: envoy.filters.listener.proxy_protocol
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol
- name: envoy.filters.listener.tls_inspector
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector
filter_chains:
{{- /* TLS filter chains: per-SNI block/allow + default */ -}}
{{- range .TLSChains}}
- filter_chain_match:
transport_protocol: tls
{{- if .ServerNames}}
server_names:
{{- range .ServerNames}}
- {{quote .}}
{{- end}}
{{- end}}
filters:
{{$.NetworkFiltersSnippet}} - name: envoy.filters.network.tcp_proxy
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy
stat_prefix: {{.StatPrefix}}
cluster: original_dst
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] tcp %DOWNSTREAM_REMOTE_ADDRESS% -> %UPSTREAM_HOST% %RESPONSE_FLAGS% %DURATION%ms\n"
{{- end}}
{{- /* Plain HTTP filter chain with per-domain virtual hosts */}}
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: inspect_http
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] http %DOWNSTREAM_REMOTE_ADDRESS% %REQ(:AUTHORITY)% %REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %RESPONSE_CODE% %RESPONSE_FLAGS% %DURATION%ms\n"
http_filters:
{{.HTTPFiltersSnippet}} - name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
route_config:
virtual_hosts:
{{- range .VirtualHosts}}
- name: {{.Name}}
domains: [{{.DomainsStr}}]
routes:
{{- range .Routes}}
- match:
prefix: "{{if .PathPrefix}}{{.PathPrefix}}{{else}}/{{end}}"
{{- if .Block}}
direct_response:
status: 403
body:
filename: "{{$.BlockPagePath}}"
{{- else}}
route:
cluster: original_dst
{{- end}}
{{- end}}
{{- end}}
clusters:
- name: original_dst
type: ORIGINAL_DST
lb_policy: CLUSTER_PROVIDED
connect_timeout: 10s
{{.ExtraClusters}}`))
// tlsChain represents a TLS filter chain entry for the template.
// All TLS chains are passthrough (block decisions happen in Go before envoy).
type tlsChain struct {
// ServerNames restricts this chain to specific SNIs. Empty is catch-all.
ServerNames []string
StatPrefix string
}
// envoyRoute represents a single route entry within a virtual host.
type envoyRoute struct {
// PathPrefix for envoy prefix match. Empty means catch-all "/".
PathPrefix string
Block bool
}
// virtualHost represents an HTTP virtual host entry for the template.
type virtualHost struct {
Name string
// DomainsStr is pre-formatted for the template: "a", "b".
DomainsStr string
Routes []envoyRoute
}
type bootstrapData struct {
AdminPort uint16
ListenPort uint16
BlockPagePath string
TLSChains []tlsChain
VirtualHosts []virtualHost
HTTPFiltersSnippet string
NetworkFiltersSnippet string
ExtraClusters string
}
// generateBootstrap produces the envoy bootstrap YAML from the inspect config.
// Translates inspection rules into envoy-native per-SNI and per-domain routing.
// blockPagePath is the path to the HTML block page file served by direct_response.
func generateBootstrap(config Config, listenPort, adminPort uint16, blockPagePath string) ([]byte, error) {
data := bootstrapData{
AdminPort: adminPort,
BlockPagePath: blockPagePath,
ListenPort: listenPort,
TLSChains: buildTLSChains(config),
VirtualHosts: buildVirtualHosts(config),
}
if config.Envoy != nil && config.Envoy.Snippets != nil {
s := config.Envoy.Snippets
data.HTTPFiltersSnippet = indentSnippet(s.HTTPFilters, 18)
data.NetworkFiltersSnippet = indentSnippet(s.NetworkFilters, 12)
data.ExtraClusters = indentSnippet(s.Clusters, 4)
}
var buf bytes.Buffer
if err := envoyBootstrapTmpl.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute bootstrap template: %w", err)
}
return buf.Bytes(), nil
}
// buildTLSChains translates inspection rules into envoy TLS filter chains.
// Block rules -> per-SNI chain routing to blackhole.
// Allow rules (when default=block) -> per-SNI chain routing to original_dst.
// Default chain follows DefaultAction.
func buildTLSChains(config Config) []tlsChain {
// TLS block decisions happen in Go before forwarding to envoy, so we only
// generate allow/passthrough chains here. Envoy can't cleanly close a TLS
// connection without completing a handshake, so blocked SNIs never reach envoy.
var allowed []string
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTPS, ProtoH2) {
continue
}
for _, d := range rule.Domains {
sni := d.PunycodeString()
if rule.Action == ActionAllow || rule.Action == ActionInspect {
allowed = append(allowed, sni)
}
}
}
var chains []tlsChain
if len(allowed) > 0 && config.DefaultAction == ActionBlock {
chains = append(chains, tlsChain{
ServerNames: allowed,
StatPrefix: "tls_allowed",
})
}
// Default catch-all: passthrough (blocked SNIs never arrive here)
chains = append(chains, tlsChain{
StatPrefix: "tls_default",
})
return chains
}
// buildVirtualHosts translates inspection rules into envoy HTTP virtual hosts.
// Groups rules by domain, generates per-path routes within each virtual host.
func buildVirtualHosts(config Config) []virtualHost {
// Group rules by domain for per-domain virtual hosts.
type domainRules struct {
domains []string
routes []envoyRoute
}
domainRouteMap := make(map[string][]envoyRoute)
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTP, ProtoWebSocket) {
continue
}
isBlock := rule.Action == ActionBlock
// Rules without domains or paths are handled by the default action.
if len(rule.Domains) == 0 && len(rule.Paths) == 0 {
continue
}
// Build routes for this rule's paths
var routes []envoyRoute
if len(rule.Paths) > 0 {
for _, p := range rule.Paths {
// Convert our path patterns to envoy prefix match.
// Strip trailing * for envoy prefix matching.
prefix := strings.TrimSuffix(p, "*")
routes = append(routes, envoyRoute{PathPrefix: prefix, Block: isBlock})
}
} else {
routes = append(routes, envoyRoute{Block: isBlock})
}
if len(rule.Domains) > 0 {
for _, d := range rule.Domains {
host := d.PunycodeString()
domainRouteMap[host] = append(domainRouteMap[host], routes...)
}
} else {
// No domain: applies to all, add to default host
domainRouteMap["*"] = append(domainRouteMap["*"], routes...)
}
}
var hosts []virtualHost
idx := 0
// Per-domain virtual hosts with path routes
for domain, routes := range domainRouteMap {
if domain == "*" {
continue
}
// Add a catch-all route after path-specific routes.
// The catch-all follows the default action.
routes = append(routes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: fmt.Sprintf("domain_%d", idx),
DomainsStr: fmt.Sprintf("%q", domain),
Routes: routes,
})
idx++
}
// Default virtual host (catch-all for unmatched domains)
defaultRoutes := domainRouteMap["*"]
defaultRoutes = append(defaultRoutes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: "default",
DomainsStr: `"*"`,
Routes: defaultRoutes,
})
return hosts
}
// ruleTouchesProtocol returns true if the rule's protocol list includes any of the given protocols,
// or if the protocol list is empty (matches all).
func ruleTouchesProtocol(rule Rule, protos ...ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, rp := range rule.Protocols {
for _, p := range protos {
if rp == p {
return true
}
}
}
return false
}
// indentSnippet prepends each line of the YAML snippet with the given number of spaces.
// Returns empty string if snippet is empty.
func indentSnippet(snippet string, spaces int) string {
if snippet == "" {
return ""
}
prefix := make([]byte, spaces)
for i := range prefix {
prefix[i] = ' '
}
var buf bytes.Buffer
for i, line := range bytes.Split([]byte(snippet), []byte("\n")) {
if i > 0 {
buf.WriteByte('\n')
}
if len(line) > 0 {
buf.Write(prefix)
buf.Write(line)
}
}
buf.WriteByte('\n')
return buf.String()
}
// ValidateSnippets checks that user-provided snippets are safe to inject
// into the envoy config. Returns an error describing the first violation found.
//
// Validation rules:
// - Each snippet must be valid YAML (prevents syntax-level injection)
// - Snippets must not contain YAML document separators (--- or ...) that could
// break out of the indentation context
// - Snippets must only contain list items (starting with "- ") at the top level,
// matching what envoy expects for filters and clusters
func ValidateSnippets(snippets *EnvoySnippets) error {
if snippets == nil {
return nil
}
fields := []struct {
name string
value string
}{
{"http_filters", snippets.HTTPFilters},
{"network_filters", snippets.NetworkFilters},
{"clusters", snippets.Clusters},
}
for _, f := range fields {
if f.value == "" {
continue
}
if err := validateSnippetYAML(f.name, f.value); err != nil {
return err
}
}
return nil
}
func validateSnippetYAML(name, snippet string) error {
// Check for YAML document markers that could break template structure.
for _, line := range strings.Split(snippet, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == "---" || trimmed == "..." {
return fmt.Errorf("snippet %q: YAML document separators (--- or ...) are not allowed", name)
}
}
// Verify it's valid YAML by checking it doesn't cause template execution issues.
// We can't import yaml.v3 here without adding a dependency, so we do structural checks.
// Check for null bytes or control characters that could confuse YAML parsers.
for i, b := range []byte(snippet) {
if b == 0 {
return fmt.Errorf("snippet %q: null byte at position %d", name, i)
}
if b < 0x09 || (b > 0x0D && b < 0x20 && b != 0x1B) {
return fmt.Errorf("snippet %q: control character 0x%02x at position %d", name, b, i)
}
}
return nil
}

View File

@@ -1,88 +0,0 @@
package inspect
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
)
// PROXY protocol v2 constants (RFC 7239 / HAProxy spec)
var proxyV2Signature = [12]byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51,
0x55, 0x49, 0x54, 0x0A,
}
const (
proxyV2VersionCommand = 0x21 // version 2, PROXY command
proxyV2FamilyTCP4 = 0x11 // AF_INET, STREAM
proxyV2FamilyTCP6 = 0x21 // AF_INET6, STREAM
)
// forwardToEnvoy forwards a connection to the given envoy sidecar via PROXY protocol v2.
// The caller provides the envoy manager snapshot to avoid accessing p.envoy without lock.
func (p *Proxy) forwardToEnvoy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo, em *envoyManager) error {
envoyAddr := em.ListenAddr()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", envoyAddr.String())
if err != nil {
return fmt.Errorf("dial envoy at %s: %w", envoyAddr, err)
}
defer func() {
if err := conn.Close(); err != nil {
p.log.Debugf("close envoy conn: %v", err)
}
}()
if err := writeProxyV2Header(conn, src.IP, dst); err != nil {
return fmt.Errorf("write PROXY v2 header: %w", err)
}
p.log.Tracef("envoy: forwarded %s -> %s via PROXY v2", src.IP, dst)
return relay(ctx, pconn, conn)
}
// writeProxyV2Header writes a PROXY protocol v2 header to w.
// The header encodes the original source IP and the destination address:port.
func writeProxyV2Header(w net.Conn, srcIP netip.Addr, dst netip.AddrPort) error {
srcIP = srcIP.Unmap()
dstIP := dst.Addr().Unmap()
var (
family byte
addrs []byte
)
if srcIP.Is4() && dstIP.Is4() {
family = proxyV2FamilyTCP4
s4 := srcIP.As4()
d4 := dstIP.As4()
addrs = make([]byte, 12) // 4+4+2+2
copy(addrs[0:4], s4[:])
copy(addrs[4:8], d4[:])
binary.BigEndian.PutUint16(addrs[8:10], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[10:12], dst.Port())
} else {
family = proxyV2FamilyTCP6
s16 := srcIP.As16()
d16 := dstIP.As16()
addrs = make([]byte, 36) // 16+16+2+2
copy(addrs[0:16], s16[:])
copy(addrs[16:32], d16[:])
binary.BigEndian.PutUint16(addrs[32:34], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[34:36], dst.Port())
}
// Header: signature(12) + ver_cmd(1) + family(1) + len(2) + addrs
header := make([]byte, 16+len(addrs))
copy(header[0:12], proxyV2Signature[:])
header[12] = proxyV2VersionCommand
header[13] = family
binary.BigEndian.PutUint16(header[14:16], uint16(len(addrs)))
copy(header[16:], addrs)
_, err := w.Write(header)
return err
}

View File

@@ -1,13 +0,0 @@
//go:build !windows
package inspect
import (
"os"
"syscall"
)
// signalReload sends SIGHUP to the envoy process to trigger config reload.
func signalReload(p *os.Process) error {
return p.Signal(syscall.SIGHUP)
}

View File

@@ -1,13 +0,0 @@
//go:build windows
package inspect
import (
"fmt"
"os"
)
// signalReload is not supported on Windows. Envoy must be restarted.
func signalReload(_ *os.Process) error {
return fmt.Errorf("envoy config reload via signal not supported on Windows")
}

View File

@@ -1,229 +0,0 @@
package inspect
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"time"
)
const (
externalDialTimeout = 10 * time.Second
)
// handleExternal forwards the connection to an external proxy.
// For TLS connections, it uses HTTP CONNECT to tunnel through the proxy.
// For HTTP connections, it rewrites the request to use the proxy.
func (p *Proxy) handleExternal(ctx context.Context, pconn *peekConn, dst netip.AddrPort) error {
p.mu.RLock()
proxyURL := p.config.ExternalURL
p.mu.RUnlock()
if proxyURL == nil {
return fmt.Errorf("external proxy URL not configured")
}
switch proxyURL.Scheme {
case "http", "https":
return p.externalHTTPProxy(ctx, pconn, dst, proxyURL)
case "socks5":
return p.externalSOCKS5(ctx, pconn, dst, proxyURL)
default:
return fmt.Errorf("unsupported external proxy scheme: %s", proxyURL.Scheme)
}
}
// externalHTTPProxy tunnels through an HTTP proxy using CONNECT.
func (p *Proxy) externalHTTPProxy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "8080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial external proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close external proxy conn: %v", err)
}
}()
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", dst.String(), dst.String())
if proxyURL.User != nil {
connectReq += "Proxy-Authorization: Basic " + basicAuth(proxyURL.User) + "\r\n"
}
connectReq += "\r\n"
if _, err := io.WriteString(proxyConn, connectReq); err != nil {
return fmt.Errorf("send CONNECT to proxy: %w", err)
}
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), nil)
if err != nil {
return fmt.Errorf("read CONNECT response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close CONNECT resp body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
}
return relay(ctx, pconn, proxyConn)
}
// externalSOCKS5 tunnels through a SOCKS5 proxy.
func (p *Proxy) externalSOCKS5(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "1080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial SOCKS5 proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close SOCKS5 proxy conn: %v", err)
}
}()
if err := socks5Handshake(proxyConn, dst, proxyURL.User); err != nil {
return fmt.Errorf("SOCKS5 handshake: %w", err)
}
return relay(ctx, pconn, proxyConn)
}
// socks5Handshake performs the SOCKS5 handshake to connect through the proxy.
func socks5Handshake(conn net.Conn, dst netip.AddrPort, userinfo *url.Userinfo) error {
needAuth := userinfo != nil
// Greeting
var methods []byte
if needAuth {
methods = []byte{0x00, 0x02} // no auth, username/password
} else {
methods = []byte{0x00} // no auth
}
greeting := append([]byte{0x05, byte(len(methods))}, methods...)
if _, err := conn.Write(greeting); err != nil {
return fmt.Errorf("send greeting: %w", err)
}
// Server method selection
var methodResp [2]byte
if _, err := io.ReadFull(conn, methodResp[:]); err != nil {
return fmt.Errorf("read method selection: %w", err)
}
if methodResp[0] != 0x05 {
return fmt.Errorf("unexpected SOCKS version: %d", methodResp[0])
}
// Handle authentication if selected
if methodResp[1] == 0x02 {
if err := socks5Auth(conn, userinfo); err != nil {
return err
}
} else if methodResp[1] != 0x00 {
return fmt.Errorf("unsupported SOCKS5 auth method: %d", methodResp[1])
}
// Connection request
addr := dst.Addr()
var addrBytes []byte
if addr.Is4() {
a4 := addr.As4()
addrBytes = append([]byte{0x01}, a4[:]...) // IPv4
} else {
a16 := addr.As16()
addrBytes = append([]byte{0x04}, a16[:]...) // IPv6
}
port := dst.Port()
connectReq := append([]byte{0x05, 0x01, 0x00}, addrBytes...)
connectReq = append(connectReq, byte(port>>8), byte(port))
if _, err := conn.Write(connectReq); err != nil {
return fmt.Errorf("send connect request: %w", err)
}
// Read response (minimum 10 bytes for IPv4)
var respHeader [4]byte
if _, err := io.ReadFull(conn, respHeader[:]); err != nil {
return fmt.Errorf("read connect response: %w", err)
}
if respHeader[1] != 0x00 {
return fmt.Errorf("SOCKS5 connect failed: status %d", respHeader[1])
}
// Skip bound address
switch respHeader[3] {
case 0x01: // IPv4
var skip [4 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv4 address: %w", err)
}
case 0x04: // IPv6
var skip [16 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv6 address: %w", err)
}
case 0x03: // Domain
var dLen [1]byte
if _, err := io.ReadFull(conn, dLen[:]); err != nil {
return fmt.Errorf("read domain length: %w", err)
}
skip := make([]byte, int(dLen[0])+2)
if _, err := io.ReadFull(conn, skip); err != nil {
return fmt.Errorf("read SOCKS5 bound domain address: %w", err)
}
}
return nil
}
func socks5Auth(conn net.Conn, userinfo *url.Userinfo) error {
if userinfo == nil {
return fmt.Errorf("SOCKS5 auth required but no credentials provided")
}
user := userinfo.Username()
pass, _ := userinfo.Password()
// Username/password auth (RFC 1929)
auth := []byte{0x01, byte(len(user))}
auth = append(auth, []byte(user)...)
auth = append(auth, byte(len(pass)))
auth = append(auth, []byte(pass)...)
if _, err := conn.Write(auth); err != nil {
return fmt.Errorf("send auth: %w", err)
}
var resp [2]byte
if _, err := io.ReadFull(conn, resp[:]); err != nil {
return fmt.Errorf("read auth response: %w", err)
}
if resp[1] != 0x00 {
return fmt.Errorf("SOCKS5 auth failed: status %d", resp[1])
}
return nil
}
func basicAuth(userinfo *url.Userinfo) string {
user := userinfo.Username()
pass, _ := userinfo.Password()
return base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
}

View File

@@ -1,532 +0,0 @@
package inspect
import (
"bufio"
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
headerUpgrade = "Upgrade"
valueWebSocket = "websocket"
)
// inspectHTTP runs the HTTP inspection pipeline on decrypted traffic.
// It handles HTTP/1.1 (request-response loop), HTTP/2 (via Go stdlib reverse proxy),
// and WebSocket upgrade detection.
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
if proto == "h2" {
return p.inspectH2(ctx, client, remote, dst, sni, src)
}
return p.inspectH1(ctx, client, remote, dst, sni, src)
}
// inspectH1 handles HTTP/1.1 request-response inspection in a loop.
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
clientReader := bufio.NewReader(client)
remoteReader := bufio.NewReader(remote)
for {
if ctx.Err() != nil {
return ctx.Err()
}
// Set idle timeout between requests to prevent connection hogging.
if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
return fmt.Errorf("set idle deadline: %w", err)
}
req, err := http.ReadRequest(clientReader)
if err != nil {
if isClosedErr(err) {
return nil
}
return fmt.Errorf("read HTTP request: %w", err)
}
if err := client.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
// Re-evaluate rules based on Host header if SNI was empty
host := hostFromRequest(req, sni)
// Domain fronting: Host header doesn't match TLS SNI
if isDomainFronting(req, sni) {
p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
proto := ProtoHTTP
if isWebSocketUpgrade(req) {
proto = ProtoWebSocket
}
action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
if action == ActionBlock {
p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)
// ICAP REQMOD: send request for inspection.
// Snapshot ICAP client under lock to avoid use-after-close races.
p.mu.RLock()
icap := p.icap
p.mu.RUnlock()
if icap != nil {
modified, err := icap.ReqMod(req)
if err != nil {
p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
// Fail-closed: block on ICAP error
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP REQMOD: %w", err)
}
req = modified
}
if isWebSocketUpgrade(req) {
return p.handleWebSocket(ctx, req, client, clientReader, remote, remoteReader)
}
removeHopByHopHeaders(req.Header)
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward request: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read HTTP response: %w", err)
}
// ICAP RESPMOD: send response for inspection
if icap != nil {
modified, err := icap.RespMod(req, resp)
if err != nil {
p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP RESPMOD: %w", err)
}
resp = modified
}
removeHopByHopHeaders(resp.Header)
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close resp body: %v", closeErr)
}
return fmt.Errorf("forward response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
// Connection: close means we're done
if resp.Close || req.Close {
return nil
}
}
}
// inspectH2 proxies HTTP/2 traffic using Go's http stack.
// Client and remote are already-established TLS connections with h2 negotiated.
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
// For h2 MITM inspection, we use a local http.Server reading from the client
// connection and an http.Transport writing to the remote connection.
//
// The transport is configured to use the existing TLS connection to the
// real server. The handler inspects each request/response pair.
transport := &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
ForceAttemptHTTP2: true,
}
handler := &h2InspectionHandler{
proxy: p,
transport: transport,
dst: dst,
sni: sni,
src: src,
}
server := &http.Server{
Handler: handler,
}
// Serve the single client connection.
// ServeConn blocks until the connection is done.
errCh := make(chan error, 1)
go func() {
// http.Server doesn't have a direct ServeConn for h2,
// so we use Serve with a single-connection listener.
ln := &singleConnListener{conn: client}
errCh <- server.Serve(ln)
}()
select {
case <-ctx.Done():
if err := server.Close(); err != nil {
p.log.Debugf("close h2 server: %v", err)
}
return ctx.Err()
case err := <-errCh:
if err == http.ErrServerClosed {
return nil
}
return err
}
}
// h2InspectionHandler inspects each HTTP/2 request/response pair.
type h2InspectionHandler struct {
proxy *Proxy
transport http.RoundTripper
dst netip.AddrPort
sni domain.Domain
src SourceInfo
}
func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
host := hostFromRequest(req, h.sni)
if isDomainFronting(req, h.sni) {
h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString())
writeBlockPage(w, host)
return
}
action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path)
if action == ActionBlock {
h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockPage(w, host)
return
}
// ICAP REQMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.ReqMod(req)
if err != nil {
h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
req = modified
}
// Forward to upstream
req.URL.Scheme = "https"
req.URL.Host = h.sni.PunycodeString()
req.RequestURI = ""
resp, err := h.transport.RoundTrip(req)
if err != nil {
h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
h.proxy.log.Debugf("close h2 resp body: %v", err)
}
}()
// ICAP RESPMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.RespMod(req, resp)
if err != nil {
h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
resp = modified
}
// Copy response headers and body
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
h.proxy.log.Debugf("h2 response copy error: %v", err)
}
}
// handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally.
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error {
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward WebSocket upgrade: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read WebSocket upgrade response: %w", err)
}
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close ws resp body: %v", closeErr)
}
return fmt.Errorf("forward WebSocket upgrade response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close ws resp body: %v", err)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode)
}
p.log.Tracef("allow: WebSocket upgrade for %s", req.Host)
// Relay WebSocket frames bidirectionally.
// clientReader/remoteReader may have buffered data.
clientConn := mergeReadWriter(clientReader, client)
remoteConn := mergeReadWriter(remoteReader, remote)
return relayRW(ctx, clientConn, remoteConn)
}
// hostFromRequest extracts a domain.Domain from the HTTP request Host header,
// falling back to the SNI if Host is empty or an IP.
func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain {
host := req.Host
if host == "" {
return fallback
}
// Strip port if present
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// If it's an IP address, use the SNI fallback
if _, err := netip.ParseAddr(host); err == nil {
return fallback
}
d, err := domain.FromString(host)
if err != nil {
return fallback
}
return d
}
// isDomainFronting detects domain fronting: the Host header doesn't match the
// SNI used during the TLS handshake. Only meaningful when SNI is non-empty
// (i.e., we're in MITM mode and know the original SNI).
func isDomainFronting(req *http.Request, sni domain.Domain) bool {
if sni == "" {
return false
}
host := hostFromRequest(req, "")
if host == "" {
return false
}
// Host should match SNI or be a subdomain of SNI
if host == sni {
return false
}
// Allow www.example.com when SNI is example.com
sniStr := sni.PunycodeString()
hostStr := host.PunycodeString()
if strings.HasSuffix(hostStr, "."+sniStr) {
return false
}
return true
}
func isWebSocketUpgrade(req *http.Request) bool {
return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket)
}
// writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path).
func writeBlockPage(w http.ResponseWriter, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-store")
w.WriteHeader(http.StatusForbidden)
io.WriteString(w, body)
}
func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
resp := &http.Response{
StatusCode: http.StatusForbidden,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
ContentLength: int64(len(body)),
Body: io.NopCloser(strings.NewReader(body)),
}
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Header.Set("Connection", "close")
resp.Header.Set("Cache-Control", "no-store")
_ = resp.Write(w)
}
// blockPageHTML is the self-contained HTML block page.
// Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain.
const blockPageHTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>Blocked - %s</title>
<style>
*{margin:0;padding:0;box-sizing:border-box}
body{background:#181a1d;color:#d1d5db;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;min-height:100vh;display:flex;align-items:center;justify-content:center}
.c{text-align:center;max-width:460px;padding:2rem}
.shield{width:56px;height:56px;margin:0 auto 1.5rem;border-radius:16px;background:#2b2f33;display:flex;align-items:center;justify-content:center}
.shield svg{width:28px;height:28px;color:#f68330}
.code{font-size:.8rem;font-weight:500;color:#f68330;font-family:ui-monospace,monospace;letter-spacing:.05em;margin-bottom:.5rem}
h1{font-size:1.5rem;font-weight:600;color:#f4f4f5;margin-bottom:.5rem}
p{font-size:.95rem;line-height:1.5;color:#9ca3af;margin-bottom:1.75rem}
.domain{display:inline-block;background:#25282d;border:1px solid #32363d;border-radius:6px;padding:.15rem .5rem;font-family:ui-monospace,monospace;font-size:.85rem;color:#d1d5db}
.footer{font-size:.7rem;color:#6b7280;margin-top:2rem;letter-spacing:.03em}
.footer a{color:#6b7280;text-decoration:none}
.footer a:hover{color:#9ca3af}
</style>
</head>
<body>
<div class="c">
<div class="shield"><svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m0-10.036A11.959 11.959 0 0 1 3.598 6 11.99 11.99 0 0 0 3 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751A11.96 11.96 0 0 0 12 3.714Z"/></svg></div>
<div class="code">403 BLOCKED</div>
<h1>Access Denied</h1>
<p>This connection to <span class="domain">%s</span> has been blocked by your organization's network policy.</p>
<div class="footer">Protected by <a href="https://netbird.io" target="_blank" rel="noopener">NetBird</a></div>
</div>
</body>
</html>`
// singleConnListener is a net.Listener that yields a single connection.
type singleConnListener struct {
conn net.Conn
once sync.Once
ch chan struct{}
}
func (l *singleConnListener) Accept() (net.Conn, error) {
var accepted bool
l.once.Do(func() {
l.ch = make(chan struct{})
accepted = true
})
if accepted {
return l.conn, nil
}
// Block until Close
<-l.ch
return nil, net.ErrClosed
}
func (l *singleConnListener) Close() error {
l.once.Do(func() {
l.ch = make(chan struct{})
})
select {
case <-l.ch:
default:
close(l.ch)
}
return nil
}
func (l *singleConnListener) Addr() net.Addr {
return l.conn.LocalAddr()
}
type readWriter struct {
io.Reader
io.Writer
}
func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
return &readWriter{Reader: r, Writer: w}
}
// relayRW copies data bidirectionally between two ReadWriters.
func relayRW(ctx context.Context, a, b io.ReadWriter) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(b, a)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(a, b)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// hopByHopHeaders are HTTP/1.1 headers that apply to a single connection
// and must not be forwarded by a proxy (RFC 7230, Section 6.1).
var hopByHopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"TE",
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
// removeHopByHopHeaders strips hop-by-hop headers from h.
// Also removes headers listed in the Connection header value.
func removeHopByHopHeaders(h http.Header) {
// First, remove any headers named in the Connection header
for _, connHeader := range h["Connection"] {
for _, name := range strings.Split(connHeader, ",") {
h.Del(strings.TrimSpace(name))
}
}
for _, name := range hopByHopHeaders {
h.Del(name)
}
}

View File

@@ -1,479 +0,0 @@
package inspect
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
icapVersion = "ICAP/1.0"
icapDefaultPort = "1344"
icapConnTimeout = 30 * time.Second
icapRWTimeout = 60 * time.Second
icapMaxPoolSize = 8
icapIdleTimeout = 60 * time.Second
icapMaxRespSize = 4 * 1024 * 1024 // 4 MB
)
// ICAPClient implements an ICAP (RFC 3507) client with persistent connection pooling.
type ICAPClient struct {
reqModURL *url.URL
respModURL *url.URL
pool chan *icapConn
mu sync.Mutex
log *log.Entry
maxPool int
}
type icapConn struct {
conn net.Conn
reader *bufio.Reader
lastUse time.Time
}
// NewICAPClient creates an ICAP client. Either or both URLs may be nil
// to disable that mode.
func NewICAPClient(logger *log.Entry, cfg *ICAPConfig) *ICAPClient {
maxPool := cfg.MaxConnections
if maxPool <= 0 {
maxPool = icapMaxPoolSize
}
return &ICAPClient{
reqModURL: cfg.ReqModURL,
respModURL: cfg.RespModURL,
pool: make(chan *icapConn, maxPool),
log: logger,
maxPool: maxPool,
}
}
// ReqMod sends an HTTP request to the ICAP REQMOD service for inspection.
// Returns the (possibly modified) request, or the original if ICAP returns 204.
// Returns nil, nil if REQMOD is not configured.
func (c *ICAPClient) ReqMod(req *http.Request) (*http.Request, error) {
if c.reqModURL == nil {
return req, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
respBody, err := c.send("REQMOD", c.reqModURL, reqBuf.Bytes(), nil)
if err != nil {
return nil, err
}
if respBody == nil {
return req, nil
}
modified, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(respBody)))
if err != nil {
return nil, fmt.Errorf("parse ICAP modified request: %w", err)
}
return modified, nil
}
// RespMod sends an HTTP response to the ICAP RESPMOD service for inspection.
// Returns the (possibly modified) response, or the original if ICAP returns 204.
// Returns nil, nil if RESPMOD is not configured.
func (c *ICAPClient) RespMod(req *http.Request, resp *http.Response) (*http.Response, error) {
if c.respModURL == nil {
return resp, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
var respBuf bytes.Buffer
if err := resp.Write(&respBuf); err != nil {
return nil, fmt.Errorf("serialize response: %w", err)
}
respBody, err := c.send("RESPMOD", c.respModURL, reqBuf.Bytes(), respBuf.Bytes())
if err != nil {
return nil, err
}
if respBody == nil {
// 204 No Content: ICAP server didn't modify the response.
// Reconstruct from the buffered copy since resp.Body was consumed by Write.
reconstructed, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBuf.Bytes())), req)
if err != nil {
return nil, fmt.Errorf("reconstruct response after ICAP 204: %w", err)
}
return reconstructed, nil
}
modified, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBody)), req)
if err != nil {
return nil, fmt.Errorf("parse ICAP modified response: %w", err)
}
return modified, nil
}
// Close drains and closes all pooled connections.
func (c *ICAPClient) Close() {
close(c.pool)
for ic := range c.pool {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close ICAP connection: %v", err)
}
}
}
// send executes an ICAP request and returns the encapsulated body from the response.
// Returns nil body for 204 No Content (no modification).
// Retries once on stale pooled connection (EOF on read).
func (c *ICAPClient) send(method string, serviceURL *url.URL, reqData, respData []byte) ([]byte, error) {
statusCode, headers, body, err := c.trySend(method, serviceURL, reqData, respData)
if err != nil && isStaleConnErr(err) {
// Retry once with a fresh connection (stale pool entry).
c.log.Debugf("ICAP %s: retrying after stale connection: %v", method, err)
statusCode, headers, body, err = c.trySend(method, serviceURL, reqData, respData)
}
if err != nil {
return nil, err
}
switch statusCode {
case 204:
return nil, nil
case 200:
return body, nil
default:
c.log.Debugf("ICAP %s returned status %d, headers: %v", method, statusCode, headers)
return nil, fmt.Errorf("ICAP %s: status %d", method, statusCode)
}
}
func (c *ICAPClient) trySend(method string, serviceURL *url.URL, reqData, respData []byte) (int, textproto.MIMEHeader, []byte, error) {
ic, err := c.getConn(serviceURL)
if err != nil {
return 0, nil, nil, fmt.Errorf("get ICAP connection: %w", err)
}
if err := c.writeRequest(ic, method, serviceURL, reqData, respData); err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after write error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("write ICAP %s: %w", method, err)
}
statusCode, headers, body, err := c.readResponse(ic)
if err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after read error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("read ICAP response: %w", err)
}
c.putConn(ic)
return statusCode, headers, body, nil
}
func isStaleConnErr(err error) bool {
if err == nil {
return false
}
s := err.Error()
return strings.Contains(s, "EOF") || strings.Contains(s, "broken pipe") || strings.Contains(s, "connection reset")
}
func (c *ICAPClient) writeRequest(ic *icapConn, method string, serviceURL *url.URL, reqData, respData []byte) error {
if err := ic.conn.SetWriteDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
// For RESPMOD, split the serialized HTTP response into headers and body.
// The body must be sent chunked per RFC 3507.
var respHdr, respBody []byte
if respData != nil {
if idx := bytes.Index(respData, []byte("\r\n\r\n")); idx >= 0 {
respHdr = respData[:idx+4] // include the \r\n\r\n separator
respBody = respData[idx+4:]
} else {
respHdr = respData
}
}
var buf bytes.Buffer
// Request line
fmt.Fprintf(&buf, "%s %s %s\r\n", method, serviceURL.String(), icapVersion)
// Headers
host := serviceURL.Host
fmt.Fprintf(&buf, "Host: %s\r\n", host)
fmt.Fprintf(&buf, "Connection: keep-alive\r\n")
fmt.Fprintf(&buf, "Allow: 204\r\n")
// Build Encapsulated header
offset := 0
var encapParts []string
if reqData != nil {
encapParts = append(encapParts, fmt.Sprintf("req-hdr=%d", offset))
offset += len(reqData)
}
if respHdr != nil {
encapParts = append(encapParts, fmt.Sprintf("res-hdr=%d", offset))
offset += len(respHdr)
}
if len(respBody) > 0 {
encapParts = append(encapParts, fmt.Sprintf("res-body=%d", offset))
} else {
encapParts = append(encapParts, fmt.Sprintf("null-body=%d", offset))
}
fmt.Fprintf(&buf, "Encapsulated: %s\r\n", strings.Join(encapParts, ", "))
fmt.Fprintf(&buf, "\r\n")
// Encapsulated sections
if reqData != nil {
buf.Write(reqData)
}
if respHdr != nil {
buf.Write(respHdr)
}
// Body in chunked encoding (only when there is an actual body section).
// Per RFC 3507 Section 4.4.1, null-body must not include any entity data.
if len(respBody) > 0 {
fmt.Fprintf(&buf, "%x\r\n", len(respBody))
buf.Write(respBody)
buf.WriteString("\r\n")
buf.WriteString("0\r\n\r\n")
}
_, err := ic.conn.Write(buf.Bytes())
return err
}
func (c *ICAPClient) readResponse(ic *icapConn) (int, textproto.MIMEHeader, []byte, error) {
if err := ic.conn.SetReadDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return 0, nil, nil, fmt.Errorf("set read deadline: %w", err)
}
tp := textproto.NewReader(ic.reader)
// Status line: "ICAP/1.0 200 OK"
statusLine, err := tp.ReadLine()
if err != nil {
return 0, nil, nil, fmt.Errorf("read status line: %w", err)
}
statusCode, err := parseICAPStatus(statusLine)
if err != nil {
return 0, nil, nil, err
}
// Headers
headers, err := tp.ReadMIMEHeader()
if err != nil {
return statusCode, nil, nil, fmt.Errorf("read ICAP headers: %w", err)
}
if statusCode == 204 {
return statusCode, headers, nil, nil
}
// Read encapsulated body based on Encapsulated header
body, err := c.readEncapsulatedBody(ic.reader, headers)
if err != nil {
return statusCode, headers, nil, fmt.Errorf("read encapsulated body: %w", err)
}
return statusCode, headers, body, nil
}
func (c *ICAPClient) readEncapsulatedBody(r *bufio.Reader, headers textproto.MIMEHeader) ([]byte, error) {
encap := headers.Get("Encapsulated")
if encap == "" {
return nil, nil
}
// Find the body offset from the Encapsulated header.
// The last section with a non-zero offset is the body.
// Read everything from the reader as the encapsulated content.
var totalSize int
parts := strings.Split(encap, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
eqIdx := strings.Index(part, "=")
if eqIdx < 0 {
continue
}
offset, err := strconv.Atoi(strings.TrimSpace(part[eqIdx+1:]))
if err != nil {
continue
}
if offset > totalSize {
totalSize = offset
}
}
// Read all available encapsulated data (headers + body)
// The body section uses chunked encoding per RFC 3507
var buf bytes.Buffer
if totalSize > 0 {
// Read the header sections (everything before the body offset)
headerBytes := make([]byte, totalSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("read encapsulated headers: %w", err)
}
buf.Write(headerBytes)
}
// Read chunked body
chunked := newChunkedReader(r)
body, err := io.ReadAll(io.LimitReader(chunked, icapMaxRespSize))
if err != nil {
return nil, fmt.Errorf("read chunked body: %w", err)
}
buf.Write(body)
return buf.Bytes(), nil
}
func (c *ICAPClient) getConn(serviceURL *url.URL) (*icapConn, error) {
// Try to get a pooled connection
for {
select {
case ic := <-c.pool:
if time.Since(ic.lastUse) > icapIdleTimeout {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close idle ICAP connection: %v", err)
}
continue
}
return ic, nil
default:
return c.dialConn(serviceURL)
}
}
}
func (c *ICAPClient) putConn(ic *icapConn) {
c.mu.Lock()
defer c.mu.Unlock()
ic.lastUse = time.Now()
select {
case c.pool <- ic:
default:
// Pool full, close connection.
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close excess ICAP connection: %v", err)
}
}
}
func (c *ICAPClient) dialConn(serviceURL *url.URL) (*icapConn, error) {
host := serviceURL.Host
if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, icapDefaultPort)
}
conn, err := net.DialTimeout("tcp", host, icapConnTimeout)
if err != nil {
return nil, fmt.Errorf("dial ICAP %s: %w", host, err)
}
return &icapConn{
conn: conn,
reader: bufio.NewReader(conn),
lastUse: time.Now(),
}, nil
}
func parseICAPStatus(line string) (int, error) {
// "ICAP/1.0 200 OK"
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
return 0, fmt.Errorf("malformed ICAP status line: %q", line)
}
code, err := strconv.Atoi(parts[1])
if err != nil {
return 0, fmt.Errorf("parse ICAP status code %q: %w", parts[1], err)
}
return code, nil
}
// chunkedReader reads ICAP chunked encoding (same as HTTP chunked, terminated by "0\r\n\r\n").
type chunkedReader struct {
r *bufio.Reader
remaining int
done bool
}
func newChunkedReader(r *bufio.Reader) *chunkedReader {
return &chunkedReader{r: r}
}
func (cr *chunkedReader) Read(p []byte) (int, error) {
if cr.done {
return 0, io.EOF
}
if cr.remaining == 0 {
// Read chunk size line
line, err := cr.r.ReadString('\n')
if err != nil {
return 0, err
}
line = strings.TrimSpace(line)
// Strip any chunk extensions
if idx := strings.Index(line, ";"); idx >= 0 {
line = line[:idx]
}
size, err := strconv.ParseInt(line, 16, 64)
if err != nil {
return 0, fmt.Errorf("parse chunk size %q: %w", line, err)
}
if size == 0 {
cr.done = true
// Consume trailing \r\n
_, _ = cr.r.ReadString('\n')
return 0, io.EOF
}
if size < 0 || size > icapMaxRespSize {
return 0, fmt.Errorf("chunk size %d out of range (max %d)", size, icapMaxRespSize)
}
cr.remaining = int(size)
}
toRead := len(p)
if toRead > cr.remaining {
toRead = cr.remaining
}
n, err := cr.r.Read(p[:toRead])
cr.remaining -= n
if cr.remaining == 0 {
// Consume chunk-terminating \r\n
_, _ = cr.r.ReadString('\n')
}
return n, err
}

View File

@@ -1,21 +0,0 @@
//go:build !linux
package inspect
import (
"fmt"
"net"
"net/netip"
log "github.com/sirupsen/logrus"
)
// newTPROXYListener is not supported on non-Linux platforms.
func newTPROXYListener(_ *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
return nil, fmt.Errorf("TPROXY listener not supported on this platform (requested %s)", addr)
}
// getOriginalDst is not supported on non-Linux platforms.
func getOriginalDst(_ net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, fmt.Errorf("SO_ORIGINAL_DST not supported on this platform")
}

View File

@@ -1,89 +0,0 @@
package inspect
import (
"fmt"
"net"
"net/netip"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// newTPROXYListener creates a TCP listener for the transparent proxy.
// After nftables REDIRECT, accepted connections have LocalAddr = WG_IP:proxy_port.
// The original destination is retrieved via getsockopt(SO_ORIGINAL_DST).
func newTPROXYListener(logger *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
ln, err := net.Listen("tcp", addr.String())
if err != nil {
return nil, fmt.Errorf("listen on %s: %w", addr, err)
}
logger.Infof("inspect: listener started on %s", ln.Addr())
return ln, nil
}
// getOriginalDst reads the original destination from conntrack via SO_ORIGINAL_DST.
// This is set by the kernel when the connection was REDIRECT'd/DNAT'd.
// Tries IPv4 first, then falls back to IPv6 (IP6T_SO_ORIGINAL_DST).
func getOriginalDst(conn net.Conn) (netip.AddrPort, error) {
tc, ok := conn.(*net.TCPConn)
if !ok {
return netip.AddrPort{}, fmt.Errorf("not a TCPConn")
}
raw, err := tc.SyscallConn()
if err != nil {
return netip.AddrPort{}, fmt.Errorf("get syscall conn: %w", err)
}
var origDst netip.AddrPort
var sockErr error
if err := raw.Control(func(fd uintptr) {
// Try IPv4 first (SO_ORIGINAL_DST = 80)
var sa4 unix.RawSockaddrInet4
sa4Len := uint32(unsafe.Sizeof(sa4))
_, _, errno := unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IP,
80, // SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa4)),
uintptr(unsafe.Pointer(&sa4Len)),
0,
)
if errno == 0 {
addr := netip.AddrFrom4(sa4.Addr)
port := uint16(sa4.Port>>8) | uint16(sa4.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
return
}
// Fall back to IPv6 (IP6T_SO_ORIGINAL_DST = 80 on SOL_IPV6)
var sa6 unix.RawSockaddrInet6
sa6Len := uint32(unsafe.Sizeof(sa6))
_, _, errno = unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IPV6,
80, // IP6T_SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa6)),
uintptr(unsafe.Pointer(&sa6Len)),
0,
)
if errno != 0 {
sockErr = fmt.Errorf("getsockopt SO_ORIGINAL_DST (v4 and v6): %w", errno)
return
}
addr := netip.AddrFrom16(sa6.Addr)
port := uint16(sa6.Port>>8) | uint16(sa6.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
}); err != nil {
return netip.AddrPort{}, fmt.Errorf("control raw conn: %w", err)
}
if sockErr != nil {
return netip.AddrPort{}, sockErr
}
return origDst, nil
}

View File

@@ -1,200 +0,0 @@
package inspect
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
mrand "math/rand/v2"
"sync"
"time"
)
const (
// certCacheSize is the maximum number of cached leaf certificates.
certCacheSize = 1024
// certTTL is how long generated certificates remain valid.
certTTL = 24 * time.Hour
)
// certCache is a bounded LRU cache for generated TLS certificates.
type certCache struct {
mu sync.Mutex
entries map[string]*certEntry
// order tracks LRU eviction, most recent at end.
order []string
maxSize int
}
type certEntry struct {
cert *tls.Certificate
expiresAt time.Time
}
func newCertCache(maxSize int) *certCache {
return &certCache{
entries: make(map[string]*certEntry, maxSize),
order: make([]string, 0, maxSize),
maxSize: maxSize,
}
}
func (c *certCache) get(hostname string) (*tls.Certificate, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.entries[hostname]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
c.removeLocked(hostname)
return nil, false
}
// Move to end (most recently used)
c.touchLocked(hostname)
return entry.cert, true
}
func (c *certCache) put(hostname string, cert *tls.Certificate) {
c.mu.Lock()
defer c.mu.Unlock()
// Jitter the TTL by +/- 20% to prevent thundering herd on expiry.
jitter := time.Duration(float64(certTTL) * (0.8 + 0.4*mrand.Float64()))
if _, exists := c.entries[hostname]; exists {
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.touchLocked(hostname)
return
}
// Evict oldest if at capacity
for len(c.entries) >= c.maxSize && len(c.order) > 0 {
c.removeLocked(c.order[0])
}
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.order = append(c.order, hostname)
}
func (c *certCache) touchLocked(hostname string) {
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
c.order = append(c.order, hostname)
return
}
}
}
func (c *certCache) removeLocked(hostname string) {
delete(c.entries, hostname)
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
// CertProvider generates TLS certificates on the fly, signed by a CA.
// Generated certificates are cached in an LRU cache.
type CertProvider struct {
ca *x509.Certificate
caKey crypto.PrivateKey
cache *certCache
}
// NewCertProvider creates a certificate provider using the given CA.
func NewCertProvider(ca *x509.Certificate, caKey crypto.PrivateKey) *CertProvider {
return &CertProvider{
ca: ca,
caKey: caKey,
cache: newCertCache(certCacheSize),
}
}
// GetCertificate returns a TLS certificate for the given hostname,
// generating and caching one if necessary.
func (p *CertProvider) GetCertificate(hostname string) (*tls.Certificate, error) {
if cert, ok := p.cache.get(hostname); ok {
return cert, nil
}
cert, err := p.generateCert(hostname)
if err != nil {
return nil, fmt.Errorf("generate cert for %s: %w", hostname, err)
}
p.cache.put(hostname, cert)
return cert, nil
}
// GetTLSConfig returns a tls.Config that dynamically provides certificates
// for any hostname using the MITM CA.
func (p *CertProvider) GetTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return p.GetCertificate(hello.ServerName)
},
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
}
}
func (p *CertProvider) generateCert(hostname string) (*tls.Certificate, error) {
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: hostname,
},
NotBefore: now.Add(-5 * time.Minute),
NotAfter: now.Add(certTTL),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
DNSNames: []string{hostname},
}
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate leaf key: %w", err)
}
certDER, err := x509.CreateCertificate(rand.Reader, template, p.ca, &leafKey.PublicKey, p.caKey)
if err != nil {
return nil, fmt.Errorf("sign leaf certificate: %w", err)
}
leafCert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated certificate: %w", err)
}
return &tls.Certificate{
Certificate: [][]byte{certDER, p.ca.Raw},
PrivateKey: leafKey,
Leaf: leafCert,
}, nil
}

View File

@@ -1,133 +0,0 @@
package inspect
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert, key
}
func TestCertProvider_GetCertificate(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert, err := provider.GetCertificate("example.com")
require.NoError(t, err)
require.NotNil(t, cert)
// Verify the leaf certificate
assert.Equal(t, "example.com", cert.Leaf.Subject.CommonName)
assert.Contains(t, cert.Leaf.DNSNames, "example.com")
// Verify chain: leaf + CA
assert.Len(t, cert.Certificate, 2)
// Verify leaf is signed by our CA
pool := x509.NewCertPool()
pool.AddCert(ca)
_, err = cert.Leaf.Verify(x509.VerifyOptions{
Roots: pool,
})
require.NoError(t, err)
}
func TestCertProvider_CachesResults(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
// Same pointer = cached
assert.Equal(t, cert1, cert2)
}
func TestCertProvider_DifferentHostsDifferentCerts(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("a.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("b.example.com")
require.NoError(t, err)
assert.NotEqual(t, cert1.Leaf.SerialNumber, cert2.Leaf.SerialNumber)
}
func TestCertProvider_TLSConfigHandshake(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
tlsConfig := provider.GetTLSConfig()
require.NotNil(t, tlsConfig)
require.NotNil(t, tlsConfig.GetCertificate)
// Simulate a ClientHelloInfo
hello := &tls.ClientHelloInfo{
ServerName: "handshake.example.com",
}
cert, err := tlsConfig.GetCertificate(hello)
require.NoError(t, err)
assert.Equal(t, "handshake.example.com", cert.Leaf.Subject.CommonName)
}
func TestCertCache_Eviction(t *testing.T) {
cache := newCertCache(3)
for i := range 5 {
hostname := string(rune('a'+i)) + ".example.com"
cache.put(hostname, &tls.Certificate{})
}
// Only 3 should remain (c, d, e - the most recent)
assert.Len(t, cache.entries, 3)
_, ok := cache.get("a.example.com")
assert.False(t, ok, "oldest entry should be evicted")
_, ok = cache.get("b.example.com")
assert.False(t, ok, "second oldest should be evicted")
_, ok = cache.get("e.example.com")
assert.True(t, ok, "newest entry should exist")
}

View File

@@ -1,109 +0,0 @@
package inspect
import (
"bytes"
"fmt"
"io"
"net"
)
// peekConn wraps a net.Conn with a buffer that allows reading ahead
// without consuming data. Subsequent Read calls return the buffered
// bytes first, then read from the underlying connection.
type peekConn struct {
net.Conn
buf bytes.Buffer
// peeked holds the raw bytes that were peeked, available for replay.
peeked []byte
}
// newPeekConn wraps conn for peek-ahead reading.
func newPeekConn(conn net.Conn) *peekConn {
return &peekConn{Conn: conn}
}
// Peek reads exactly n bytes from the connection without consuming them.
// The peeked bytes are replayed on subsequent Read calls.
// Peek may only be called once; calling it again returns an error.
func (c *peekConn) Peek(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
if _, err := io.ReadFull(c.Conn, buf); err != nil {
return nil, fmt.Errorf("peek %d bytes: %w", n, err)
}
c.peeked = buf
c.buf.Write(buf)
return buf, nil
}
// PeekAll reads up to n bytes, returning whatever is available.
// Unlike Peek, it does not require exactly n bytes.
func (c *peekConn) PeekAll(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
nr, err := c.Conn.Read(buf)
if nr > 0 {
c.peeked = buf[:nr]
c.buf.Write(c.peeked)
}
if err != nil && nr == 0 {
return nil, fmt.Errorf("peek: %w", err)
}
return c.peeked, nil
}
// PeekMore extends the peeked buffer to at least n total bytes.
// The buffer is reset and refilled with the extended data.
// The returned slice is the internal peeked buffer; callers must not
// retain references from prior Peek/PeekMore calls after calling this.
func (c *peekConn) PeekMore(n int) ([]byte, error) {
if len(c.peeked) >= n {
return c.peeked[:n], nil
}
remaining := n - len(c.peeked)
extra := make([]byte, remaining)
if _, err := io.ReadFull(c.Conn, extra); err != nil {
return nil, fmt.Errorf("peek more %d bytes: %w", remaining, err)
}
// Pre-allocate to avoid reallocation detaching previously returned slices.
combined := make([]byte, 0, n)
combined = append(combined, c.peeked...)
combined = append(combined, extra...)
c.peeked = combined
c.buf.Reset()
c.buf.Write(c.peeked)
return c.peeked, nil
}
// Peeked returns the bytes that were peeked so far, or nil if Peek hasn't been called.
func (c *peekConn) Peeked() []byte {
return c.peeked
}
// Read returns buffered peek data first, then reads from the underlying connection.
func (c *peekConn) Read(p []byte) (int, error) {
if c.buf.Len() > 0 {
return c.buf.Read(p)
}
return c.Conn.Read(p)
}
// reader returns an io.Reader that replays buffered bytes then reads from conn.
func (c *peekConn) reader() io.Reader {
if c.buf.Len() > 0 {
return io.MultiReader(&c.buf, c.Conn)
}
return c.Conn
}

View File

@@ -1,482 +0,0 @@
package inspect
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ErrBlocked is returned when a connection is denied by proxy policy.
var ErrBlocked = errors.New("connection blocked by proxy policy")
const (
// headerReadTimeout is the deadline for reading the initial protocol header.
// Prevents slow loris attacks where a client opens a connection but sends data slowly.
headerReadTimeout = 10 * time.Second
// idleTimeout is the deadline for idle connections between HTTP requests.
idleTimeout = 120 * time.Second
)
// Proxy is the inspection engine for traffic passing through a NetBird
// routing peer. It handles protocol detection, rule evaluation, MITM TLS
// decryption, ICAP delegation, and external proxy forwarding.
type Proxy struct {
config Config
rules *RuleEngine
certs *CertProvider
icap *ICAPClient
// envoy is nil unless mode is ModeEnvoy.
envoy *envoyManager
// dialer is the outbound dialer (with SO_MARK cleared on Linux).
dialer net.Dialer
log *log.Entry
// wgNetwork is the WG overlay prefix; dial targets inside it are blocked.
wgNetwork netip.Prefix
// localIPs reports the routing peer's own IPs; dial targets are blocked.
localIPs LocalIPChecker
// listener is the TPROXY/REDIRECT listener for kernel mode.
listener net.Listener
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
// LocalIPChecker reports whether an IP belongs to the local machine.
type LocalIPChecker interface {
IsLocalIP(netip.Addr) bool
}
// New creates a transparent proxy with the given configuration.
func New(ctx context.Context, logger *log.Entry, config Config) (*Proxy, error) {
ctx, cancel := context.WithCancel(ctx)
p := &Proxy{
config: config,
rules: NewRuleEngine(logger, config.DefaultAction),
dialer: newOutboundDialer(),
log: logger,
wgNetwork: config.WGNetwork,
localIPs: config.LocalIPChecker,
ctx: ctx,
cancel: cancel,
}
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Initialize MITM certificate provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
}
// Initialize ICAP client
if config.ICAP != nil {
p.icap = NewICAPClient(logger, config.ICAP)
}
// Start envoy sidecar if configured
if config.Mode == ModeEnvoy {
envoyLog := logger.WithField("sidecar", "envoy")
em, err := startEnvoy(ctx, envoyLog, config)
if err != nil {
cancel()
return nil, fmt.Errorf("start envoy sidecar: %w", err)
}
p.envoy = em
}
// Start TPROXY listener for kernel mode
if config.ListenAddr.IsValid() {
ln, err := newTPROXYListener(logger, config.ListenAddr, netip.Prefix{})
if err != nil {
cancel()
return nil, fmt.Errorf("start TPROXY listener on %s: %w", config.ListenAddr, err)
}
p.listener = ln
go p.acceptLoop(ln)
}
return p, nil
}
// HandleTCP is the entry point for TCP connections from the userspace forwarder.
// It determines the protocol (TLS or plaintext HTTP), evaluates rules,
// and either blocks, passes through, inspects, or forwards to an external proxy.
func (p *Proxy) HandleTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) error {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
p.mu.RLock()
mode := p.config.Mode
p.mu.RUnlock()
if mode == ModeExternal {
pconn := newPeekConn(clientConn)
return p.handleExternal(ctx, pconn, dst)
}
// Envoy and builtin modes both peek the protocol header for rule evaluation.
// Envoy mode forwards non-blocked traffic to envoy; builtin mode handles all locally.
// TLS blocks are handled by Go (instant close) since envoy can't cleanly RST a TLS connection.
// Built-in and envoy mode: peek 5 bytes (TLS record header size) to determine protocol.
// Set a read deadline to prevent slow loris attacks.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
return fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
if isTLSHandshake(header[0]) {
return p.handleTLS(ctx, pconn, dst, src)
}
if isHTTPMethod(header) {
return p.handlePlainHTTP(ctx, pconn, dst, src)
}
// Not TLS and not HTTP: evaluate rules with ProtoOther.
// If no rule explicitly allows "other", this falls through to the default action.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial for passthrough: %w", err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote conn: %v", err)
}
}()
return relay(ctx, pconn, remote)
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
return ErrBlocked
}
// InspectTCP evaluates rules for a TCP connection and returns the result.
// Unlike HandleTCP, it can return early for allow decisions, letting the caller
// handle the relay (USP forwarder passthrough optimization).
//
// When InspectResult.PassthroughConn is non-nil, ownership transfers to the caller:
// the caller must close the connection and relay traffic. The engine does not close it.
//
// When PassthroughConn is nil, the engine handled everything internally
// (block, inspect/MITM, or plain HTTP inspection) and closed the connection.
func (p *Proxy) InspectTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
p.mu.RLock()
mode := p.config.Mode
envoy := p.envoy
p.mu.RUnlock()
// External mode: handle internally, engine owns the connection.
if mode == ModeExternal {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
pconn := newPeekConn(clientConn)
err := p.handleExternal(ctx, pconn, dst)
return InspectResult{Action: ActionAllow}, err
}
// Peek protocol header.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("clear read deadline: %w", err)
}
// TLS: may return passthrough for allow.
if isTLSHandshake(header[0]) {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil && result.PassthroughConn == nil {
clientConn.Close()
return result, err
}
// Envoy mode: forward allowed TLS to envoy instead of returning passthrough.
if result.PassthroughConn != nil && envoy != nil {
defer clientConn.Close()
envoyErr := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, envoyErr
}
return result, err
}
// Plain HTTP: in envoy mode, forward to envoy for L7 processing.
// In builtin mode, inspect per-request locally.
if isHTTPMethod(header) {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
if envoy != nil {
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
err := p.handlePlainHTTP(ctx, pconn, dst, src)
return InspectResult{Action: ActionInspect}, err
}
// Other protocol: evaluate rules.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
// Envoy mode: forward to envoy.
if envoy != nil {
defer clientConn.Close()
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
clientConn.Close()
return InspectResult{Action: ActionBlock}, ErrBlocked
}
// HandleUDPPacket inspects a UDP packet for QUIC Initial packets.
// Returns the action to take: ActionAllow to continue normal forwarding,
// ActionBlock to drop the packet.
// Non-QUIC packets always return ActionAllow.
func (p *Proxy) HandleUDPPacket(data []byte, dst netip.AddrPort, src SourceInfo) Action {
if len(data) < 5 {
return ActionAllow
}
// Check for QUIC Long Header
if data[0]&0x80 == 0 {
return ActionAllow
}
sni, err := ExtractQUICSNI(data)
if err != nil {
// Can't parse QUIC, allow through (could be non-QUIC UDP)
p.log.Tracef("QUIC SNI extraction failed for %s: %v", dst, err)
return ActionAllow
}
if sni == "" {
return ActionAllow
}
action := p.rules.Evaluate(src.IP, sni, dst.Addr(), dst.Port(), ProtoH3, "")
if action == ActionBlock {
p.log.Debugf("block: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
return ActionBlock
}
// QUIC can't be MITMed, treat Inspect as Allow
if action == ActionInspect {
p.log.Debugf("allow: QUIC to %s (SNI=%s), MITM not supported for QUIC", dst, sni.PunycodeString())
} else {
p.log.Tracef("allow: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
}
return ActionAllow
}
// handlePlainHTTP handles plaintext HTTP connections.
func (p *Proxy) handlePlainHTTP(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
// For plaintext HTTP, always inspect (we can see the traffic)
return p.inspectHTTP(ctx, pconn, remote, dst, "", src, "http/1.1")
}
// UpdateConfig replaces the inspection engine configuration at runtime.
func (p *Proxy) UpdateConfig(config Config) {
p.log.Debugf("config update: mode=%s rules=%d default=%s has_tls=%v has_icap=%v",
config.Mode, len(config.Rules), config.DefaultAction, config.TLS != nil, config.ICAP != nil)
p.mu.Lock()
p.config = config
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Update MITM provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
} else {
p.certs = nil
}
// Swap ICAP client under lock, close the old one outside to avoid blocking.
var oldICAP *ICAPClient
if config.ICAP != nil {
oldICAP = p.icap
p.icap = NewICAPClient(p.log, config.ICAP)
} else {
oldICAP = p.icap
p.icap = nil
}
// If switching away from envoy mode, clear and stop the old envoy.
var oldEnvoy *envoyManager
if config.Mode != ModeEnvoy && p.envoy != nil {
oldEnvoy = p.envoy
p.envoy = nil
}
envoy := p.envoy
p.mu.Unlock()
if oldICAP != nil {
oldICAP.Close()
}
if oldEnvoy != nil {
oldEnvoy.Stop()
}
// Reload envoy config if still in envoy mode.
if envoy != nil && config.Mode == ModeEnvoy {
if err := envoy.Reload(config); err != nil {
p.log.Errorf("inspect: envoy config reload: %v", err)
}
}
}
// Mode returns the current proxy operating mode.
func (p *Proxy) Mode() ProxyMode {
p.mu.RLock()
defer p.mu.RUnlock()
return p.config.Mode
}
// ListenPort returns the port to use for kernel-mode nftables REDIRECT.
// For builtin mode: the TPROXY listener port.
// For envoy mode: the envoy listener port (nftables redirects directly to envoy).
// Returns 0 if no listener is active.
func (p *Proxy) ListenPort() uint16 {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return envoy.listenPort
}
if p.listener == nil {
return 0
}
tcpAddr, ok := p.listener.Addr().(*net.TCPAddr)
if !ok {
return 0
}
return uint16(tcpAddr.Port)
}
// Close shuts down the proxy and releases resources.
func (p *Proxy) Close() error {
p.cancel()
p.mu.Lock()
envoy := p.envoy
p.envoy = nil
icap := p.icap
p.icap = nil
p.mu.Unlock()
if envoy != nil {
envoy.Stop()
}
if p.listener != nil {
if err := p.listener.Close(); err != nil {
p.log.Debugf("close TPROXY listener: %v", err)
}
}
if icap != nil {
icap.Close()
}
return nil
}
// acceptLoop accepts connections from the redirected listener (kernel mode).
// Connections arrive via nftables REDIRECT; original destination is read from conntrack.
func (p *Proxy) acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
if p.ctx.Err() != nil {
return
}
p.log.Debugf("accept error: %v", err)
continue
}
go func() {
// Read original destination from conntrack (SO_ORIGINAL_DST).
// nftables REDIRECT changes dst to the local WG IP:proxy_port,
// but conntrack preserves the real destination.
dstAddr, err := getOriginalDst(conn)
if err != nil {
p.log.Debugf("get original dst: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
p.log.Tracef("accepted: %s -> %s (original dst %s)",
conn.RemoteAddr(), conn.LocalAddr(), dstAddr)
srcAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
if err != nil {
p.log.Debugf("parse source: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
src := SourceInfo{
IP: srcAddr.Addr().Unmap(),
}
if err := p.HandleTCP(p.ctx, conn, dstAddr, src); err != nil && !errors.Is(err, ErrBlocked) {
p.log.Debugf("connection to %s: %v", dstAddr, err)
}
}()
}
}

View File

@@ -1,388 +0,0 @@
package inspect
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
"github.com/netbirdio/netbird/shared/management/domain"
)
// QUIC version constants
const (
quicV1Version uint32 = 0x00000001
quicV2Version uint32 = 0x6b3343cf
)
// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
var quicV1Salt = []byte{
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3,
0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
0xcc, 0xbb, 0x7f, 0x0a,
}
// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
var quicV2Salt = []byte{
0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb,
0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb,
0xf9, 0xbd, 0x2e, 0xd9,
}
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
// The Initial packet's encryption uses well-known keys derived from the
// Destination Connection ID, so any observer can decrypt it (by design).
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
if len(data) < 5 {
return "", fmt.Errorf("packet too short")
}
// Check for QUIC Long Header (form bit set)
if data[0]&0x80 == 0 {
return "", fmt.Errorf("not a QUIC long header packet")
}
// Version
version := binary.BigEndian.Uint32(data[1:5])
var salt []byte
var initialLabel, keyLabel, ivLabel, hpLabel string
switch version {
case quicV1Version:
salt = quicV1Salt
initialLabel = "client in"
keyLabel = "quic key"
ivLabel = "quic iv"
hpLabel = "quic hp"
case quicV2Version:
salt = quicV2Salt
initialLabel = "client in"
keyLabel = "quicv2 key"
ivLabel = "quicv2 iv"
hpLabel = "quicv2 hp"
default:
return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
}
// Parse Long Header
if len(data) < 6 {
return "", fmt.Errorf("packet too short for DCID length")
}
dcidLen := int(data[5])
if len(data) < 6+dcidLen+1 {
return "", fmt.Errorf("packet too short for DCID")
}
dcid := data[6 : 6+dcidLen]
scidLenOff := 6 + dcidLen
scidLen := int(data[scidLenOff])
tokenLenOff := scidLenOff + 1 + scidLen
if tokenLenOff >= len(data) {
return "", fmt.Errorf("packet too short for token length")
}
// Token length is a variable-length integer
tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
if err != nil {
return "", fmt.Errorf("read token length: %w", err)
}
payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)
if payloadLenOff >= len(data) {
return "", fmt.Errorf("packet too short for payload length")
}
// Payload length is a variable-length integer
payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
if err != nil {
return "", fmt.Errorf("read payload length: %w", err)
}
pnOffset := payloadLenOff + payloadLenSize
if pnOffset+4 > len(data) {
return "", fmt.Errorf("packet too short for packet number")
}
// Derive initial keys
clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
if err != nil {
return "", fmt.Errorf("derive initial keys: %w", err)
}
// Remove header protection
sampleOffset := pnOffset + 4 // sample starts 4 bytes after pn offset
if sampleOffset+16 > len(data) {
return "", fmt.Errorf("packet too short for HP sample")
}
sample := data[sampleOffset : sampleOffset+16]
hpBlock, err := aes.NewCipher(clientHP)
if err != nil {
return "", fmt.Errorf("create HP cipher: %w", err)
}
mask := make([]byte, 16)
hpBlock.Encrypt(mask, sample)
// Unmask header byte
header := make([]byte, len(data))
copy(header, data)
header[0] ^= mask[0] & 0x0f // Long header: low 4 bits
// Determine packet number length
pnLen := int(header[0]&0x03) + 1
// Unmask packet number
for i := 0; i < pnLen; i++ {
header[pnOffset+i] ^= mask[1+i]
}
// Reconstruct packet number
var pn uint32
for i := 0; i < pnLen; i++ {
pn = (pn << 8) | uint32(header[pnOffset+i])
}
// Build nonce
nonce := make([]byte, len(clientIV))
copy(nonce, clientIV)
for i := 0; i < 4; i++ {
nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
}
// Decrypt payload
block, err := aes.NewCipher(clientKey)
if err != nil {
return "", fmt.Errorf("create AES cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create AEAD: %w", err)
}
encryptedPayload := header[pnOffset+pnLen : pnOffset+int(payloadLen)]
aad := header[:pnOffset+pnLen]
plaintext, err := aead.Open(nil, nonce, encryptedPayload, aad)
if err != nil {
return "", fmt.Errorf("decrypt QUIC payload: %w", err)
}
// Parse CRYPTO frames to extract ClientHello
clientHello, err := extractCryptoFrames(plaintext)
if err != nil {
return "", fmt.Errorf("extract CRYPTO frames: %w", err)
}
info, err := parseHelloBody(clientHello)
return info.SNI, err
}
// deriveInitialKeys derives the client's initial encryption keys from the DCID.
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
// initial_secret = HKDF-Extract(salt, DCID)
initialSecret := hkdf.Extract(sha256.New, dcid, salt)
// client_initial_secret = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
clientSecret, err := hkdfExpandLabel(initialSecret, initialLabel, nil, 32)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
}
// client_key = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
key, err = hkdfExpandLabel(clientSecret, keyLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive key: %w", err)
}
// client_iv = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
iv, err = hkdfExpandLabel(clientSecret, ivLabel, nil, 12)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive IV: %w", err)
}
// client_hp = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
hp, err = hkdfExpandLabel(clientSecret, hpLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive HP key: %w", err)
}
return key, iv, hp, nil
}
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label.
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
// HkdfLabel = struct {
// uint16 length;
// opaque label<7..255> = "tls13 " + Label;
// opaque context<0..255> = Context;
// }
fullLabel := "tls13 " + label
hkdfLabel := make([]byte, 2+1+len(fullLabel)+1+len(context))
binary.BigEndian.PutUint16(hkdfLabel[0:2], uint16(length))
hkdfLabel[2] = byte(len(fullLabel))
copy(hkdfLabel[3:], fullLabel)
hkdfLabel[3+len(fullLabel)] = byte(len(context))
if len(context) > 0 {
copy(hkdfLabel[4+len(fullLabel):], context)
}
expander := hkdf.Expand(sha256.New, secret, hkdfLabel)
out := make([]byte, length)
if _, err := io.ReadFull(expander, out); err != nil {
return nil, err
}
return out, nil
}
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
const maxCryptoFrameSize = 64 * 1024
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
func extractCryptoFrames(frames []byte) ([]byte, error) {
var result []byte
pos := 0
for pos < len(frames) {
frameType := frames[pos]
switch {
case frameType == 0x00:
// PADDING frame
pos++
case frameType == 0x06:
// CRYPTO frame
pos++
offset, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto offset: %w", err)
}
pos += n
_ = offset // We assume ordered, offset 0 for Initial
dataLen, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto data length: %w", err)
}
pos += n
end := pos + int(dataLen)
if end > len(frames) {
return nil, fmt.Errorf("CRYPTO frame data truncated")
}
result = append(result, frames[pos:end]...)
if len(result) > maxCryptoFrameSize {
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
}
pos = end
case frameType == 0x01:
// PING frame
pos++
case frameType == 0x02 || frameType == 0x03:
// ACK frame - skip
pos++
// Largest Acknowledged
_, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK: %w", err)
}
pos += n
// ACK Delay
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK delay: %w", err)
}
pos += n
// ACK Range Count
rangeCount, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range count: %w", err)
}
pos += n
// First ACK Range
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read first ACK range: %w", err)
}
pos += n
// Additional ranges
for i := uint64(0); i < rangeCount; i++ {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK gap: %w", err)
}
pos += n
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range: %w", err)
}
pos += n
}
// ECN counts for type 0x03
if frameType == 0x03 {
for range 3 {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ECN count: %w", err)
}
pos += n
}
}
default:
// Unknown frame type, stop parsing
if len(result) > 0 {
return result, nil
}
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
}
}
if len(result) == 0 {
return nil, fmt.Errorf("no CRYPTO frames found")
}
return result, nil
}
// readVarInt reads a QUIC variable-length integer.
// Returns (value, bytes consumed, error).
func readVarInt(data []byte) (uint64, int, error) {
if len(data) == 0 {
return 0, 0, fmt.Errorf("empty data for varint")
}
prefix := data[0] >> 6
length := 1 << prefix
if len(data) < length {
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
}
var val uint64
switch length {
case 1:
val = uint64(data[0] & 0x3f)
case 2:
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
case 4:
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
case 8:
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
}
return val, length, nil
}

View File

@@ -1,99 +0,0 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadVarInt(t *testing.T) {
tests := []struct {
name string
data []byte
want uint64
n int
}{
{
name: "1 byte value",
data: []byte{0x25},
want: 37,
n: 1,
},
{
name: "2 byte value",
data: []byte{0x7b, 0xbd},
want: 15293,
n: 2,
},
{
name: "4 byte value",
data: []byte{0x9d, 0x7f, 0x3e, 0x7d},
want: 494878333,
n: 4,
},
{
name: "zero",
data: []byte{0x00},
want: 0,
n: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, n, err := readVarInt(tt.data)
require.NoError(t, err)
assert.Equal(t, tt.want, val)
assert.Equal(t, tt.n, n)
})
}
}
func TestReadVarInt_Empty(t *testing.T) {
_, _, err := readVarInt(nil)
require.Error(t, err)
}
func TestReadVarInt_Truncated(t *testing.T) {
// 2-byte prefix but only 1 byte
_, _, err := readVarInt([]byte{0x40})
require.Error(t, err)
}
func TestExtractQUICSNI_NotLongHeader(t *testing.T) {
// Short header packet (form bit not set)
data := make([]byte, 100)
data[0] = 0x40 // short header
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "not a QUIC long header")
}
func TestExtractQUICSNI_UnsupportedVersion(t *testing.T) {
data := make([]byte, 100)
data[0] = 0xC0 // long header
// Version 0xdeadbeef
data[1] = 0xde
data[2] = 0xad
data[3] = 0xbe
data[4] = 0xef
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported QUIC version")
}
func TestExtractQUICSNI_TooShort(t *testing.T) {
_, err := ExtractQUICSNI([]byte{0xC0, 0x00})
require.Error(t, err)
}
func TestHkdfExpandLabel(t *testing.T) {
// Smoke test: ensure it returns the right length and doesn't error
secret := make([]byte, 32)
result, err := hkdfExpandLabel(secret, "quic key", nil, 16)
require.NoError(t, err)
assert.Len(t, result, 16)
}

View File

@@ -1,253 +0,0 @@
package inspect
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// RuleEngine evaluates proxy rules against connection metadata.
// It is safe for concurrent use.
type RuleEngine struct {
mu sync.RWMutex
rules []Rule
// defaultAction applies when no rule matches.
defaultAction Action
log *log.Entry
}
// NewRuleEngine creates a rule engine with the given default action.
func NewRuleEngine(logger *log.Entry, defaultAction Action) *RuleEngine {
return &RuleEngine{
defaultAction: defaultAction,
log: logger,
}
}
// UpdateRules replaces the rule set and default action. Rules are sorted by priority.
func (e *RuleEngine) UpdateRules(rules []Rule, defaultAction Action) {
sorted := make([]Rule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
e.mu.Lock()
e.rules = sorted
e.defaultAction = defaultAction
e.mu.Unlock()
}
// EvalResult holds the outcome of a rule evaluation.
type EvalResult struct {
Action Action
RuleID id.RuleID
}
// Evaluate determines the action for a connection based on the rule set.
// Pass empty path for connection-level evaluation (TLS/SNI), non-empty for request-level (HTTP).
func (e *RuleEngine) Evaluate(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) Action {
r := e.EvaluateWithResult(src, dstDomain, dstAddr, dstPort, proto, path)
return r.Action
}
// EvaluateWithResult is like Evaluate but also returns the matched rule ID.
func (e *RuleEngine) EvaluateWithResult(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) EvalResult {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
rule := &e.rules[i]
if e.ruleMatches(rule, src, dstDomain, dstAddr, dstPort, proto, path) {
e.log.Tracef("rule %s matched: action=%s src=%s domain=%s dst=%s:%d proto=%s path=%s",
rule.ID, rule.Action, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: rule.Action, RuleID: rule.ID}
}
}
e.log.Tracef("no rule matched, default=%s: src=%s domain=%s dst=%s:%d proto=%s path=%s",
e.defaultAction, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: e.defaultAction}
}
// HasPathRulesForDomain returns true if any rule matching the domain has non-empty Paths.
// Used to force MITM inspection when path-level rules exist (paths are only visible after decryption).
func (e *RuleEngine) HasPathRulesForDomain(dstDomain domain.Domain) bool {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
if len(e.rules[i].Paths) > 0 && e.matchDomain(&e.rules[i], dstDomain) {
return true
}
}
return false
}
// ruleMatches checks whether all non-empty fields of a rule match.
// Empty fields are treated as "match any".
// All specified fields must match (AND logic).
func (e *RuleEngine) ruleMatches(rule *Rule, src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) bool {
if !e.matchSource(rule, src) {
return false
}
if !e.matchDomain(rule, dstDomain) {
return false
}
if !e.matchNetwork(rule, dstAddr) {
return false
}
if !e.matchPort(rule, dstPort) {
return false
}
if !e.matchProtocol(rule, proto) {
return false
}
if !e.matchPaths(rule, path) {
return false
}
return true
}
// matchSource returns true if src matches any of the rule's source CIDRs,
// or if no source CIDRs are specified (match any).
func (e *RuleEngine) matchSource(rule *Rule, src netip.Addr) bool {
if len(rule.Sources) == 0 {
return true
}
for _, prefix := range rule.Sources {
if prefix.Contains(src) {
return true
}
}
return false
}
// matchDomain returns true if dstDomain matches any of the rule's domain patterns,
// or if no domain patterns are specified (match any).
func (e *RuleEngine) matchDomain(rule *Rule, dstDomain domain.Domain) bool {
if len(rule.Domains) == 0 {
return true
}
// If we have domain rules but no domain to match against (e.g., raw IP connection),
// the domain condition does not match.
if dstDomain == "" {
return false
}
for _, pattern := range rule.Domains {
if MatchDomain(pattern, dstDomain) {
return true
}
}
return false
}
// matchNetwork returns true if dstAddr is within any of the rule's destination CIDRs,
// or if no destination CIDRs are specified (match any).
func (e *RuleEngine) matchNetwork(rule *Rule, dstAddr netip.Addr) bool {
if len(rule.Networks) == 0 {
return true
}
for _, prefix := range rule.Networks {
if prefix.Contains(dstAddr) {
return true
}
}
return false
}
// matchProtocol returns true if proto matches any of the rule's protocols,
// or if no protocols are specified (match any).
func (e *RuleEngine) matchProtocol(rule *Rule, proto ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, p := range rule.Protocols {
if p == proto {
return true
}
}
return false
}
// matchPort returns true if dstPort matches any of the rule's destination ports,
// or if no ports are specified (match any).
func (e *RuleEngine) matchPort(rule *Rule, dstPort uint16) bool {
if len(rule.Ports) == 0 {
return true
}
return slices.Contains(rule.Ports, dstPort)
}
// matchPaths returns true if path matches any of the rule's path patterns,
// or if no paths are specified (match any). Empty path (connection-level eval) matches all.
func (e *RuleEngine) matchPaths(rule *Rule, path string) bool {
if len(rule.Paths) == 0 {
return true
}
// Connection-level (path=""): rules with paths don't match at connection level.
// HasPathRulesForDomain forces the connection to inspect, so paths are
// checked per-request once the HTTP request is visible.
if path == "" {
return false
}
for _, pattern := range rule.Paths {
if matchPath(pattern, path) {
return true
}
}
return false
}
// matchPath checks if a URL path matches a pattern.
// Supports: exact ("/login"), prefix with wildcard ("/api/*"),
// and contains ("*/admin/*"). A bare "*" matches everything.
func matchPath(pattern, path string) bool {
if pattern == "*" {
return true
}
hasLeadingStar := strings.HasPrefix(pattern, "*")
hasTrailingStar := strings.HasSuffix(pattern, "*")
switch {
case hasLeadingStar && hasTrailingStar:
// */admin/* = contains
middle := strings.Trim(pattern, "*")
return strings.Contains(path, middle)
case hasTrailingStar:
// /api/* = prefix
prefix := strings.TrimSuffix(pattern, "*")
return strings.HasPrefix(path, prefix)
case hasLeadingStar:
// *.json = suffix
suffix := strings.TrimPrefix(pattern, "*")
return strings.HasSuffix(path, suffix)
default:
// exact
return path == pattern
}
}

View File

@@ -1,338 +0,0 @@
package inspect
import (
"net/netip"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
func testLogger() *log.Entry {
return log.WithField("test", true)
}
func mustDomain(t *testing.T, s string) domain.Domain {
t.Helper()
d, err := domain.FromString(s)
require.NoError(t, err)
return d
}
func TestRuleEngine_Evaluate(t *testing.T) {
tests := []struct {
name string
rules []Rule
defaultAction Action
src netip.Addr
dstDomain domain.Domain
dstAddr netip.Addr
dstPort uint16
want Action
}{
{
name: "no rules returns default allow",
defaultAction: ActionAllow,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "no rules returns default block",
defaultAction: ActionBlock,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain exact match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "malware.example.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "phishing.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard does not match base",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "case insensitive domain match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "Example.COM")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "EXAMPLE.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "source CIDR match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.50"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "source CIDR no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.5"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "destination network match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.1"),
dstAddr: netip.MustParseAddr("10.50.0.1"),
dstPort: 80,
want: ActionInspect,
},
{
name: "port match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "port no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 22,
want: ActionAllow,
},
{
name: "priority ordering first match wins",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("allow-internal"),
Domains: []domain.Domain{mustDomain(t, "*.internal.corp")},
Action: ActionAllow,
Priority: 1,
},
{
ID: id.RuleID("inspect-all"),
Action: ActionInspect,
Priority: 10,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "api.internal.corp"),
dstAddr: netip.MustParseAddr("10.1.0.5"),
dstPort: 443,
want: ActionAllow,
},
{
name: "all fields must match (AND logic)",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Ports: []uint16{443},
Action: ActionBlock,
},
},
// Source matches, domain matches, but port doesn't
src: netip.MustParseAddr("192.168.1.10"),
dstDomain: mustDomain(t, "phish.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 8080,
want: ActionAllow,
},
{
name: "empty domain with domain rule does not match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: "", // raw IP connection, no SNI
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := NewRuleEngine(testLogger(), tt.defaultAction)
engine.UpdateRules(tt.rules, tt.defaultAction)
got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
assert.Equal(t, tt.want, got)
})
}
}
func TestRuleEngine_ProtocolMatching(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-websocket",
Protocols: []ProtoType{ProtoWebSocket},
Action: ActionBlock,
Priority: 1,
},
{
ID: "inspect-h2",
Protocols: []ProtoType{ProtoH2},
Action: ActionInspect,
Priority: 2,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
// WebSocket: blocked by rule
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
// HTTP/2: inspected by rule
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
// Plain HTTP: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
// HTTPS: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
// QUIC/H3: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
// Empty protocol (unknown): no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-all-protos",
Action: ActionBlock,
// No Protocols field = match all protocols
Priority: 1,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTP, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{ID: "c", Priority: 30, Action: ActionBlock},
{ID: "a", Priority: 10, Action: ActionInspect},
{ID: "b", Priority: 20, Action: ActionAllow},
}, ActionAllow)
engine.mu.RLock()
defer engine.mu.RUnlock()
require.Len(t, engine.rules, 3)
assert.Equal(t, id.RuleID("a"), engine.rules[0].ID)
assert.Equal(t, id.RuleID("b"), engine.rules[1].ID)
assert.Equal(t, id.RuleID("c"), engine.rules[2].ID)
}

View File

@@ -1,287 +0,0 @@
package inspect
import (
"encoding/binary"
"fmt"
"io"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
recordTypeHandshake = 0x16
handshakeTypeClientHello = 0x01
extensionTypeSNI = 0x0000
extensionTypeALPN = 0x0010
sniTypeHostName = 0x00
// maxClientHelloSize is the maximum ClientHello size we'll read.
// Real-world ClientHellos are typically under 1KB but can reach ~16KB with
// many extensions (post-quantum key shares, etc.).
maxClientHelloSize = 16384
)
// ClientHelloInfo holds data extracted from a TLS ClientHello.
type ClientHelloInfo struct {
SNI domain.Domain
ALPN []string
}
// isTLSHandshake reports whether the first byte indicates a TLS handshake record.
func isTLSHandshake(b byte) bool {
return b == recordTypeHandshake
}
// httpMethods lists the first bytes of valid HTTP method tokens.
var httpMethods = [][]byte{
[]byte("GET "),
[]byte("POST"),
[]byte("PUT "),
[]byte("DELE"),
[]byte("HEAD"),
[]byte("OPTI"),
[]byte("PATC"),
[]byte("CONN"),
[]byte("TRAC"),
}
// isHTTPMethod reports whether the peeked bytes look like the start of an HTTP request.
func isHTTPMethod(b []byte) bool {
if len(b) < 4 {
return false
}
for _, m := range httpMethods {
if b[0] == m[0] && b[1] == m[1] && b[2] == m[2] && b[3] == m[3] {
return true
}
}
return false
}
// parseClientHello reads a TLS ClientHello from r and returns SNI and ALPN.
func parseClientHello(r io.Reader) (ClientHelloInfo, error) {
// TLS record header: type(1) + version(2) + length(2)
var recordHeader [5]byte
if _, err := io.ReadFull(r, recordHeader[:]); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read TLS record header: %w", err)
}
if recordHeader[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", recordHeader[0])
}
recordLen := int(binary.BigEndian.Uint16(recordHeader[3:5]))
if recordLen < 4 || recordLen > maxClientHelloSize {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
// Read the full handshake message
msg := make([]byte, recordLen)
if _, err := io.ReadFull(r, msg); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read handshake message: %w", err)
}
return parseClientHelloMsg(msg)
}
// extractSNI reads a TLS ClientHello from r and returns the SNI hostname.
// Returns empty domain if no SNI extension is present.
func extractSNI(r io.Reader) (domain.Domain, error) {
info, err := parseClientHello(r)
return info.SNI, err
}
// extractSNIFromBytes parses SNI from raw bytes that start with the TLS record header.
func extractSNIFromBytes(data []byte) (domain.Domain, error) {
info, err := parseClientHelloFromBytes(data)
return info.SNI, err
}
// parseClientHelloFromBytes parses a ClientHello from raw bytes starting with the TLS record header.
func parseClientHelloFromBytes(data []byte) (ClientHelloInfo, error) {
if len(data) < 5 {
return ClientHelloInfo{}, fmt.Errorf("data too short for TLS record header")
}
if data[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", data[0])
}
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
if recordLen < 4 {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
end := 5 + recordLen
if end > len(data) {
return ClientHelloInfo{}, fmt.Errorf("TLS record truncated: need %d, have %d", end, len(data))
}
return parseClientHelloMsg(data[5:end])
}
// parseClientHelloMsg extracts SNI and ALPN from a raw ClientHello handshake message.
// msg starts at the handshake type byte.
func parseClientHelloMsg(msg []byte) (ClientHelloInfo, error) {
if len(msg) < 4 {
return ClientHelloInfo{}, fmt.Errorf("handshake message too short")
}
if msg[0] != handshakeTypeClientHello {
return ClientHelloInfo{}, fmt.Errorf("not a ClientHello (type=%d)", msg[0])
}
// Handshake header: type(1) + length(3)
helloLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
if helloLen+4 > len(msg) {
return ClientHelloInfo{}, fmt.Errorf("ClientHello truncated")
}
hello := msg[4 : 4+helloLen]
return parseHelloBody(hello)
}
// parseHelloBody parses the ClientHello body (after handshake header)
// and extracts SNI and ALPN.
func parseHelloBody(hello []byte) (ClientHelloInfo, error) {
// ClientHello structure:
// version(2) + random(32) + session_id_len(1) + session_id(var)
// + cipher_suites_len(2) + cipher_suites(var)
// + compression_len(1) + compression(var)
// + extensions_len(2) + extensions(var)
var info ClientHelloInfo
if len(hello) < 35 {
return info, fmt.Errorf("ClientHello body too short")
}
pos := 2 + 32 // skip version + random
// Skip session ID
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at session ID")
}
sessionIDLen := int(hello[pos])
pos += 1 + sessionIDLen
// Skip cipher suites
if pos+2 > len(hello) {
return info, fmt.Errorf("ClientHello truncated at cipher suites")
}
cipherLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2 + cipherLen
// Skip compression methods
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at compression")
}
compLen := int(hello[pos])
pos += 1 + compLen
// Extensions
if pos+2 > len(hello) {
return info, nil
}
extLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2
extEnd := pos + extLen
if extEnd > len(hello) {
return info, fmt.Errorf("extensions block truncated")
}
// Walk extensions looking for SNI and ALPN
for pos+4 <= extEnd {
extType := binary.BigEndian.Uint16(hello[pos : pos+2])
extDataLen := int(binary.BigEndian.Uint16(hello[pos+2 : pos+4]))
pos += 4
if pos+extDataLen > extEnd {
return info, fmt.Errorf("extension data truncated")
}
switch extType {
case extensionTypeSNI:
sni, err := parseSNIExtension(hello[pos : pos+extDataLen])
if err != nil {
return info, err
}
info.SNI = sni
case extensionTypeALPN:
info.ALPN = parseALPNExtension(hello[pos : pos+extDataLen])
}
pos += extDataLen
}
return info, nil
}
// parseALPNExtension parses the ALPN extension data and returns protocol names.
// ALPN extension: list_length(2) + entries (each: len(1) + protocol_name(var))
func parseALPNExtension(data []byte) []string {
if len(data) < 2 {
return nil
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return nil
}
var protocols []string
pos := 2
end := 2 + listLen
for pos < end {
if pos >= len(data) {
break
}
nameLen := int(data[pos])
pos++
if pos+nameLen > end {
break
}
protocols = append(protocols, string(data[pos:pos+nameLen]))
pos += nameLen
}
return protocols
}
// parseSNIExtension parses the SNI extension data and returns the hostname.
func parseSNIExtension(data []byte) (domain.Domain, error) {
// SNI extension: list_length(2) + entries
if len(data) < 2 {
return "", fmt.Errorf("SNI extension too short")
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return "", fmt.Errorf("SNI list truncated")
}
pos := 2
end := 2 + listLen
for pos+3 <= end {
nameType := data[pos]
nameLen := int(binary.BigEndian.Uint16(data[pos+1 : pos+3]))
pos += 3
if pos+nameLen > end {
return "", fmt.Errorf("SNI name truncated")
}
if nameType == sniTypeHostName {
hostname := string(data[pos : pos+nameLen])
return domain.FromString(hostname)
}
pos += nameLen
}
return "", nil
}

View File

@@ -1,109 +0,0 @@
package inspect
import (
"bytes"
"crypto/tls"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractSNI(t *testing.T) {
tests := []struct {
name string
sni string
wantSNI string
wantErr bool
}{
{
name: "standard domain",
sni: "example.com",
wantSNI: "example.com",
},
{
name: "subdomain",
sni: "api.staging.example.com",
wantSNI: "api.staging.example.com",
},
{
name: "mixed case normalized to lowercase",
sni: "Example.COM",
wantSNI: "example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientHello := buildClientHello(t, tt.sni)
sni, err := extractSNI(bytes.NewReader(clientHello))
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantSNI, sni.PunycodeString())
})
}
}
func TestExtractSNI_NotTLS(t *testing.T) {
// HTTP request instead of TLS
data := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
assert.Contains(t, err.Error(), "not a TLS handshake")
}
func TestExtractSNI_Truncated(t *testing.T) {
// Just the record header, no body
data := []byte{0x16, 0x03, 0x01, 0x00, 0x05}
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
}
func TestExtractSNIFromBytes(t *testing.T) {
clientHello := buildClientHello(t, "test.example.com")
sni, err := extractSNIFromBytes(clientHello)
require.NoError(t, err)
assert.Equal(t, "test.example.com", sni.PunycodeString())
}
// buildClientHello generates a real TLS ClientHello with the given SNI.
func buildClientHello(t *testing.T, serverName string) []byte {
t.Helper()
// Use a pipe to capture the ClientHello bytes
clientConn, serverConn := net.Pipe()
done := make(chan []byte, 1)
go func() {
buf := make([]byte, 4096)
n, _ := serverConn.Read(buf)
done <- buf[:n]
serverConn.Close()
}()
tlsConn := tls.Client(clientConn, &tls.Config{
ServerName: serverName,
InsecureSkipVerify: true,
})
// Trigger the handshake (will fail since server isn't TLS, but we capture the ClientHello)
go func() {
_ = tlsConn.Handshake()
tlsConn.Close()
}()
clientHello := <-done
clientConn.Close()
require.True(t, len(clientHello) > 5, "ClientHello too short")
require.Equal(t, byte(0x16), clientHello[0], "not a TLS handshake record")
return clientHello
}

View File

@@ -1,287 +0,0 @@
package inspect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/netip"
"github.com/netbirdio/netbird/shared/management/domain"
)
// handleTLS processes a TLS connection for the kernel-mode path: extracts SNI,
// evaluates rules, and handles the connection internally.
// In envoy mode, allowed connections are forwarded to envoy instead of direct relay.
func (p *Proxy) handleTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil {
return err
}
if result.PassthroughConn != nil {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
}
return p.tlsPassthrough(ctx, pconn, dst, "")
}
return nil
}
// inspectTLS extracts SNI, evaluates rules, and returns the result.
// For ActionAllow: returns the peekConn as PassthroughConn (caller relays).
// For ActionBlock/ActionInspect: handles internally and returns nil PassthroughConn.
func (p *Proxy) inspectTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
// The first 5 bytes (TLS record header) are already peeked.
// Extend to read the full TLS record so bytes remain in the buffer for passthrough.
peeked := pconn.Peeked()
recordLen := int(peeked[3])<<8 | int(peeked[4])
if _, err := pconn.PeekMore(5 + recordLen); err != nil {
return InspectResult{}, fmt.Errorf("read TLS record: %w", err)
}
hello, err := parseClientHelloFromBytes(pconn.Peeked())
if err != nil {
return InspectResult{}, fmt.Errorf("parse ClientHello: %w", err)
}
sni := hello.SNI
proto := protoFromALPN(hello.ALPN)
// Connection-level evaluation: pass empty path.
action := p.evaluateAction(src.IP, sni, dst, proto, "")
// If any rule for this domain has path patterns, force inspect so paths can
// be checked per-request after MITM decryption.
if action == ActionAllow && p.rules.HasPathRulesForDomain(sni) {
p.log.Debugf("upgrading to inspect for %s (path rules exist)", sni.PunycodeString())
action = ActionInspect
}
// Snapshot cert provider under lock for use in this connection.
p.mu.RLock()
certs := p.certs
p.mu.RUnlock()
switch action {
case ActionBlock:
p.log.Debugf("block: TLS to %s (SNI=%s)", dst, sni.PunycodeString())
if certs != nil {
return InspectResult{Action: ActionBlock}, p.tlsBlockPage(ctx, pconn, sni, certs)
}
return InspectResult{Action: ActionBlock}, ErrBlocked
case ActionAllow:
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
case ActionInspect:
if certs == nil {
p.log.Warnf("allow: %s (inspect requested but no MITM CA configured)", sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
err := p.tlsMITM(ctx, pconn, dst, sni, src, certs)
return InspectResult{Action: ActionInspect}, err
default:
p.log.Warnf("block: unknown action %q for %s", action, sni.PunycodeString())
return InspectResult{Action: ActionBlock}, ErrBlocked
}
}
// tlsBlockPage completes a MITM TLS handshake with the client using a dynamic
// certificate, then serves an HTTP 403 block page so the user sees a clear
// message instead of a cryptic SSL error.
func (p *Proxy) tlsBlockPage(ctx context.Context, pconn *peekConn, sni domain.Domain, certs *CertProvider) error {
hostname := sni.PunycodeString()
// Force HTTP/1.1 only: block pages are simple responses, no need for h2
tlsCfg := certs.GetTLSConfig()
tlsCfg.NextProtos = []string{"http/1.1"}
clientTLS := tls.Server(pconn, tlsCfg)
if err := clientTLS.HandshakeContext(ctx); err != nil {
// Client may not trust our CA, handshake fails. That's expected.
return fmt.Errorf("block page TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close block page TLS for %s: %v", hostname, err)
}
}()
writeBlockResponse(clientTLS, nil, sni)
return ErrBlocked
}
// tlsPassthrough connects to the destination and relays encrypted traffic
// without decryption. The peeked ClientHello bytes are replayed.
func (p *Proxy) tlsPassthrough(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return relay(ctx, pconn, remote)
}
// tlsMITM terminates the client TLS connection with a dynamic certificate,
// establishes a new TLS connection to the real destination, and runs the
// HTTP inspection pipeline on the decrypted traffic.
func (p *Proxy) tlsMITM(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, certs *CertProvider) error {
hostname := sni.PunycodeString()
// TLS handshake with client using dynamic cert
clientTLS := tls.Server(pconn, certs.GetTLSConfig())
if err := clientTLS.HandshakeContext(ctx); err != nil {
return fmt.Errorf("client TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close client TLS for %s: %v", hostname, err)
}
}()
// TLS connection to real destination
remoteTLS, err := p.dialTLS(ctx, dst, hostname)
if err != nil {
return fmt.Errorf("dial TLS %s (%s): %w", dst, hostname, err)
}
defer func() {
if err := remoteTLS.Close(); err != nil {
p.log.Debugf("close remote TLS for %s: %v", hostname, err)
}
}()
negotiatedProto := clientTLS.ConnectionState().NegotiatedProtocol
p.log.Tracef("inspect: MITM established for %s (proto=%s)", hostname, negotiatedProto)
return p.inspectHTTP(ctx, clientTLS, remoteTLS, dst, sni, src, negotiatedProto)
}
// dialTLS connects to the destination with TLS, verifying the real server certificate.
func (p *Proxy) dialTLS(ctx context.Context, dst netip.AddrPort, serverName string) (net.Conn, error) {
rawConn, err := p.dialTCP(ctx, dst)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
})
if err := tlsConn.HandshakeContext(ctx); err != nil {
if closeErr := rawConn.Close(); closeErr != nil {
p.log.Debugf("close raw conn after TLS handshake failure: %v", closeErr)
}
return nil, fmt.Errorf("TLS handshake with %s: %w", serverName, err)
}
return tlsConn, nil
}
// protoFromALPN maps TLS ALPN protocol names to proxy ProtoType.
// Falls back to ProtoHTTPS when no recognized ALPN is present.
func protoFromALPN(alpn []string) ProtoType {
for _, p := range alpn {
switch p {
case "h2":
return ProtoH2
case "h3": // unlikely in TLS, but handle anyway
return ProtoH3
}
}
// No ALPN or only "http/1.1": treat as HTTPS
return ProtoHTTPS
}
// relay copies data bidirectionally between client and remote until one
// side closes or the context is cancelled.
func relay(ctx context.Context, client, remote net.Conn) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(remote, client)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(client, remote)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// evaluateAction runs rule evaluation and resolves the effective action.
// Pass empty path for connection-level (TLS), non-empty for request-level (HTTP).
func (p *Proxy) evaluateAction(src netip.Addr, sni domain.Domain, dst netip.AddrPort, proto ProtoType, path string) Action {
return p.rules.Evaluate(src, sni, dst.Addr(), dst.Port(), proto, path)
}
// dialTCP dials the destination, blocking connections to loopback, link-local,
// multicast, and WG overlay network addresses.
func (p *Proxy) dialTCP(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
ip := dst.Addr().Unmap()
if err := p.validateDialTarget(ip); err != nil {
return nil, fmt.Errorf("dial %s: %w", dst, err)
}
return p.dialer.DialContext(ctx, "tcp", dst.String())
}
// validateDialTarget blocks destinations that should never be dialed by the proxy.
// Mirrors the route validation in systemops.validateRoute.
func (p *Proxy) validateDialTarget(addr netip.Addr) error {
switch {
case !addr.IsValid():
return fmt.Errorf("invalid address")
case addr.IsLoopback():
return fmt.Errorf("loopback address not allowed")
case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast(), addr.IsInterfaceLocalMulticast():
return fmt.Errorf("link-local address not allowed")
case addr.IsMulticast():
return fmt.Errorf("multicast address not allowed")
case p.wgNetwork.IsValid() && p.wgNetwork.Contains(addr):
return fmt.Errorf("overlay network address not allowed")
case p.localIPs != nil && p.localIPs.IsLocalIP(addr):
return fmt.Errorf("local address not allowed")
}
return nil
}
func isClosedErr(err error) bool {
if err == nil {
return false
}
return err == io.EOF ||
err == io.ErrClosedPipe ||
err == net.ErrClosed ||
err == context.Canceled
}

View File

@@ -19,9 +19,6 @@ import (
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
@@ -138,7 +135,6 @@ func TestDefaultManager(t *testing.T) {
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
@@ -198,7 +194,6 @@ func TestDefaultManagerStateless(t *testing.T) {
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@@ -263,7 +258,6 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -345,7 +339,6 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
err := a.doMgmLogin(client, ctx, pubSSHKey)
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isRegistrationNeeded(err) {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
@@ -201,7 +201,13 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
protoFlow, err := client.GetPKCEAuthorizationFlow()
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -215,7 +221,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
@@ -240,7 +246,13 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
protoFlow, err := client.GetDeviceAuthorizationFlow()
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -254,7 +266,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
@@ -280,16 +292,28 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return err
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -298,7 +322,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err

View File

@@ -23,13 +23,12 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
@@ -44,19 +43,14 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
clientMetrics *metrics.ClientMetrics
updateManager *updater.Manager
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -65,24 +59,19 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
}
}
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
c.updateManager = um
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
@@ -111,7 +100,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -121,7 +109,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
@@ -144,34 +131,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
}()
// Stop metrics push on exit
defer func() {
if c.clientMetrics != nil {
c.clientMetrics.StopPush()
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
nbnet.Init()
// Initialize metrics once at startup (always active for debug bundles)
if c.clientMetrics == nil {
agentInfo := metrics.AgentInfo{
DeploymentType: metrics.DeploymentTypeUnknown,
Version: version.NetbirdVersion(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
c.clientMetrics = metrics.NewClientMetrics(agentInfo)
log.Debugf("initialized client metrics")
// Start metrics push if enabled (uses daemon context, persists across engine restarts)
if metrics.IsMetricsPushEnabled() {
c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv())
}
}
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -224,13 +187,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
if c.updateManager != nil {
c.updateManager.CheckUpdateSuccess(c.ctx)
}
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
@@ -258,16 +222,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
// Update metrics with actual deployment type after connection
deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL())
agentInfo := metrics.AgentInfo{
DeploymentType: deploymentType,
Version: version.NetbirdVersion(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
if err = mgmClient.Close(); err != nil {
@@ -276,10 +230,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
loginStarted := time.Now()
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
@@ -288,7 +240,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
return wrapErr(err)
}
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
c.statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{
@@ -357,16 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
SignalClient: signalClient,
MgmClient: mgmClient,
RelayManager: relayManager,
StatusRecorder: c.statusRecorder,
Checks: checks,
StateManager: stateManager,
UpdateManager: c.updateManager,
ClientMetrics: c.clientMetrics,
}, mobileDependency)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -376,15 +318,21 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
select {
case <-runningChan:
default:
close(runningChan)
}
close(runningChan)
runningChan = nil
}
<-engineCtx.Done()
@@ -562,9 +510,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
InspectionCACertPath: config.InspectionCACertPath,
InspectionCAKeyPath: config.InspectionCAKeyPath,
ProfileConfig: config,
}
@@ -622,6 +567,12 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -640,7 +591,12 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -1,73 +0,0 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
// It returns an empty interface list, which means ICE P2P candidates won't be
// discovered — connections will fall back to relay. Applications that need P2P
// should provide a real implementation via runOnAndroidEmbed that uses
// Android's ConnectivityManager to enumerate network interfaces.
type noopIFaceDiscover struct{}
func (noopIFaceDiscover) IFaces() (string, error) {
// Return empty JSON array — no local interfaces advertised for ICE.
// This is intentional: without Android's ConnectivityManager, we cannot
// reliably enumerate interfaces (netlink is restricted on Android 11+).
// Relay connections still work; only P2P hole-punching is disabled.
return "[]", nil
}
// noopNetworkChangeListener is a stub for embed.Client on Android.
// Network change events are ignored since the embed client manages its own
// reconnection logic via the engine's built-in retry mechanism.
type noopNetworkChangeListener struct{}
func (noopNetworkChangeListener) OnNetworkChanged(string) {
// No-op: embed.Client relies on the engine's internal reconnection
// logic rather than OS-level network change notifications.
}
func (noopNetworkChangeListener) SetInterfaceIP(string) {
// No-op: in netstack mode, the overlay IP is managed by the userspace
// network stack, not by OS-level interface configuration.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.
type noopDnsReadyListener struct{}
func (noopDnsReadyListener) OnReady() {
// No-op: embed.Client does not need DNS readiness notifications.
// System DNS is disabled in netstack mode.
}
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -1,32 +0,0 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -1,60 +0,0 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
var scanDir = "/var/run/netbird"
// setScanDir overrides the scan directory (used by tests).
func setScanDir(dir string) {
scanDir = dir
}
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
// mismatch between the netbird@.service template (which places the socket under
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
func ResolveUnixDaemonAddr(addr string) string {
if !strings.HasPrefix(addr, "unix://") {
return addr
}
sockPath := strings.TrimPrefix(addr, "unix://")
if _, err := os.Stat(sockPath); err == nil {
return addr
}
entries, err := os.ReadDir(scanDir)
if err != nil {
return addr
}
var found []string
for _, e := range entries {
if e.IsDir() {
continue
}
if strings.HasSuffix(e.Name(), ".sock") {
found = append(found, filepath.Join(scanDir, e.Name()))
}
}
switch len(found) {
case 1:
resolved := "unix://" + found[0]
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
return resolved
case 0:
return addr
default:
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
return addr
}
}

View File

@@ -1,8 +0,0 @@
//go:build windows || ios || android
package daemonaddr
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
func ResolveUnixDaemonAddr(addr string) string {
return addr
}

View File

@@ -1,121 +0,0 @@
//go:build !windows && !ios && !android
package daemonaddr
import (
"os"
"path/filepath"
"testing"
)
// createSockFile creates a regular file with a .sock extension.
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
// sufficient and avoids Unix socket path-length limits on macOS.
func createSockFile(t *testing.T, path string) {
t.Helper()
if err := os.WriteFile(path, nil, 0o600); err != nil {
t.Fatalf("failed to create test sock file at %s: %v", path, err)
}
}
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
tmp := t.TempDir()
sock := filepath.Join(tmp, "netbird.sock")
createSockFile(t, sock)
addr := "unix://" + sock
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
tmp := t.TempDir()
// Default socket does not exist
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
// Create a scan dir with one socket
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
instanceSock := filepath.Join(sd, "main.sock")
createSockFile(t, instanceSock)
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
expected := "unix://" + instanceSock
if got != expected {
t.Errorf("expected %s, got %s", expected, got)
}
}
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
createSockFile(t, filepath.Join(sd, "main.sock"))
createSockFile(t, filepath.Join(sd, "other.sock"))
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
sd := filepath.Join(tmp, "netbird")
if err := os.MkdirAll(sd, 0o755); err != nil {
t.Fatal(err)
}
origScanDir := scanDir
setScanDir(sd)
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
addr := "tcp://127.0.0.1:41731"
got := ResolveUnixDaemonAddr(addr)
if got != addr {
t.Errorf("expected %s, got %s", addr, got)
}
}
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
tmp := t.TempDir()
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
origScanDir := scanDir
setScanDir(filepath.Join(tmp, "nonexistent"))
t.Cleanup(func() { setScanDir(origScanDir) })
got := ResolveUnixDaemonAddr(defaultAddr)
if got != defaultAddr {
t.Errorf("expected original %s, got %s", defaultAddr, got)
}
}

View File

@@ -25,13 +25,13 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -53,8 +53,6 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile.
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information.
@@ -221,11 +219,6 @@ const (
darwinStdoutLogPath = "/var/log/netbird.err.log"
)
// MetricsExporter is an interface for exporting metrics
type MetricsExporter interface {
Export(w io.Writer) error
}
type BundleGenerator struct {
anonymizer *anonymize.Anonymizer
@@ -236,7 +229,6 @@ type BundleGenerator struct {
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
anonymize bool
includeSystemInfo bool
@@ -258,7 +250,6 @@ type GeneratorDependencies struct {
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -277,7 +268,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
anonymize: cfg.Anonymize,
includeSystemInfo: cfg.IncludeSystemInfo,
@@ -361,14 +351,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addServiceParams(); err != nil {
log.Errorf("failed to add service params to debug bundle: %v", err)
}
if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err)
}
if err := g.addWgShow(); err != nil {
log.Errorf("failed to add wg show output: %v", err)
}
@@ -436,10 +418,7 @@ func (g *BundleGenerator) addStatus() error {
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
Anonymize: g.anonymize,
ProfileName: profName,
})
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
@@ -494,90 +473,6 @@ func (g *BundleGenerator) addConfig() error {
return nil
}
const (
serviceParamsFile = "service.json"
serviceParamsBundle = "service_params.json"
maskedValue = "***"
envVarPrefix = "NB_"
jsonKeyManagementURL = "management_url"
jsonKeyServiceEnv = "service_env_vars"
)
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
func (g *BundleGenerator) addServiceParams() error {
path := filepath.Join(configs.StateDir, serviceParamsFile)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("read service params: %w", err)
}
var params map[string]any
if err := json.Unmarshal(data, &params); err != nil {
return fmt.Errorf("parse service params: %w", err)
}
if g.anonymize {
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
}
}
g.sanitizeServiceEnvVars(params)
sanitizedData, err := json.MarshalIndent(params, "", " ")
if err != nil {
return fmt.Errorf("marshal sanitized service params: %w", err)
}
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
return fmt.Errorf("add service params to zip: %w", err)
}
return nil
}
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
if !ok {
return
}
sanitized := make(map[string]any, len(envVars))
for k, v := range envVars {
val, _ := v.(string)
switch {
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
sanitized[k] = maskedValue
case g.anonymize:
sanitized[k] = g.anonymizer.AnonymizeString(val)
default:
sanitized[k] = val
}
}
params[jsonKeyServiceEnv] = sanitized
}
// isSensitiveEnvVar returns true for env var names that may contain secrets.
func isSensitiveEnvVar(key string) bool {
lower := strings.ToLower(key)
for _, s := range sensitiveEnvSubstrings {
if strings.Contains(lower, s) {
return true
}
}
return false
}
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
@@ -849,30 +744,6 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
return nil
}
func (g *BundleGenerator) addMetrics() error {
if g.clientMetrics == nil {
log.Debugf("skipping metrics in debug bundle: no metrics collector")
return nil
}
var buf bytes.Buffer
if err := g.clientMetrics.Export(&buf); err != nil {
return fmt.Errorf("export metrics: %w", err)
}
if buf.Len() == 0 {
log.Debugf("skipping metrics.txt in debug bundle: no metrics data")
return nil
}
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
return fmt.Errorf("add metrics file to zip: %w", err)
}
log.Debugf("added metrics to debug bundle")
return nil
}
func (g *BundleGenerator) addLogfile() error {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")

View File

@@ -1,12 +1,8 @@
package debug
import (
"archive/zip"
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"strings"
"testing"
@@ -14,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -425,226 +420,6 @@ func TestAnonymizeNetworkMap(t *testing.T) {
}
}
func TestIsSensitiveEnvVar(t *testing.T) {
tests := []struct {
key string
sensitive bool
}{
{"NB_SETUP_KEY", true},
{"NB_API_TOKEN", true},
{"NB_CLIENT_SECRET", true},
{"NB_PASSWORD", true},
{"NB_CREDENTIAL", true},
{"NB_LOG_LEVEL", false},
{"NB_MANAGEMENT_URL", false},
{"NB_HOSTNAME", false},
{"HOME", false},
{"PATH", false},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
})
}
}
func TestSanitizeServiceEnvVars(t *testing.T) {
tests := []struct {
name string
anonymize bool
input map[string]any
check func(t *testing.T, params map[string]any)
}{
{
name: "no env vars key",
anonymize: false,
input: map[string]any{"management_url": "https://mgmt.example.com"},
check: func(t *testing.T, params map[string]any) {
t.Helper()
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
_, ok := params[jsonKeyServiceEnv]
assert.False(t, ok, "service_env_vars should not be added")
},
},
{
name: "non-NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "sensitive NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "safe NB vars anonymized when anonymize is true",
anonymize: true,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
"NB_LOG_LEVEL": "debug",
"NB_SETUP_KEY": "secret",
"SOME_OTHER": "val",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
// Safe NB_ values should be anonymized (not the original, not masked)
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
logVal := env["NB_LOG_LEVEL"].(string)
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
// Sensitive and non-NB_ still masked
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
assert.Equal(t, maskedValue, env["SOME_OTHER"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
g := &BundleGenerator{
anonymize: tt.anonymize,
anonymizer: anonymizer,
}
g.sanitizeServiceEnvVars(tt.input)
tt.check(t, tt.input)
})
}
}
func TestAddServiceParams(t *testing.T) {
t.Run("missing service.json returns nil", func(t *testing.T) {
g := &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
}
origStateDir := configs.StateDir
configs.StateDir = t.TempDir()
t.Cleanup(func() { configs.StateDir = origStateDir })
err := g.addServiceParams()
assert.NoError(t, err)
})
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
jsonKeyServiceEnv: map[string]any{
"NB_LOG_LEVEL": "trace",
},
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: true,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
require.Len(t, zr.File, 1)
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
mgmt := result[jsonKeyManagementURL].(string)
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
assert.NotEmpty(t, mgmt)
env := result[jsonKeyServiceEnv].(map[string]any)
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
})
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: false,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
})
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{

View File

@@ -73,9 +73,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
return nil
}
w.response = m
if m.MsgHdr.Truncated {
w.SetMeta("truncated", "true")
}
return w.ResponseWriter.WriteMsg(m)
}
@@ -198,14 +195,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
requestID := resutil.GenerateRequestID()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
@@ -268,9 +261,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
cw.response.Len(), meta, time.Since(startTime))
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {

View File

@@ -14,8 +14,6 @@ import (
"strings"
"sync"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -24,7 +22,6 @@ import (
const (
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
globalIPv4State = "State:/Network/Global/IPv4"
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
keySupplementalMatchDomains = "SupplementalMatchDomains"
@@ -38,14 +35,6 @@ const (
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
maxDomainsPerResolverEntry = 50
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
maxDomainBytesPerResolverEntry = 1500
)
type systemConfigurator struct {
@@ -95,23 +84,28 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
}
if err := s.removeKeysContaining(matchSuffix); err != nil {
log.Warnf("failed to remove old match keys: %v", err)
}
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 {
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
return fmt.Errorf("add match domains: %w", err)
}
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing match domains from the system")
err = s.removeKeyFromSystemConfig(matchKey)
}
if err != nil {
return fmt.Errorf("add match domains: %w", err)
}
s.updateState(stateManager)
if err := s.removeKeysContaining(searchSuffix); err != nil {
log.Warnf("failed to remove old search keys: %v", err)
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 {
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
} else {
log.Infof("removing search domains from the system")
err = s.removeKeyFromSystemConfig(searchKey)
}
if err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.updateState(stateManager)
@@ -155,7 +149,8 @@ func (s *systemConfigurator) restoreHostDNS() error {
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
if len(s.createdKeys) == 0 {
return s.discoverExistingKeys()
// return defaults for startup calls
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
}
keys := make([]string, 0, len(s.createdKeys))
@@ -165,47 +160,6 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
return keys
}
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
func (s *systemConfigurator) discoverExistingKeys() []string {
dnsKeys, err := getSystemDNSKeys()
if err != nil {
log.Errorf("failed to get system DNS keys: %v", err)
return nil
}
var keys []string
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
if strings.Contains(dnsKeys, key) {
keys = append(keys, key)
}
}
for _, suffix := range []string{searchSuffix, matchSuffix} {
for i := 0; ; i++ {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
if !strings.Contains(dnsKeys, key) {
break
}
keys = append(keys, key)
}
}
return keys
}
// getSystemDNSKeys gets all DNS keys
func getSystemDNSKeys() (string, error) {
command := "list .*DNS\nquit\n"
out, err := runSystemConfigCommand(command)
if err != nil {
return "", err
}
return string(out), nil
}
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
line := buildRemoveKeyOperation(key)
_, err := runSystemConfigCommand(wrapCommand(line))
@@ -230,11 +184,12 @@ func (s *systemConfigurator) addLocalDNS() error {
return nil
}
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
return fmt.Errorf("add local dns state: %w", err)
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
}
s.createdKeys[localKey] = struct{}{}
return nil
}
@@ -325,77 +280,28 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
func splitDomainsIntoBatches(domains []string) [][]string {
if len(domains) == 0 {
return nil
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {
return fmt.Errorf("add dns state: %w", err)
}
var batches [][]string
var current []string
currentBytes := 0
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
for _, d := range domains {
domainLen := len(d)
newBytes := currentBytes + domainLen
if currentBytes > 0 {
newBytes++ // space separator
}
s.createdKeys[key] = struct{}{}
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
batches = append(batches, current)
current = nil
currentBytes = 0
}
current = append(current, d)
if currentBytes > 0 {
currentBytes += 1 + domainLen
} else {
currentBytes = domainLen
}
}
if len(current) > 0 {
batches = append(batches, current)
}
return batches
return nil
}
// removeKeysContaining removes all created keys that contain the given substring.
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
var toRemove []string
for key := range s.createdKeys {
if strings.Contains(key, suffix) {
toRemove = append(toRemove, key)
}
}
var multiErr *multierror.Error
for _, key := range toRemove {
if err := s.removeKeyFromSystemConfig(key); err != nil {
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
}
}
return nberrors.FormatErrorOrNil(multiErr)
}
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
batches := splitDomainsIntoBatches(domains)
for i, batch := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
domainsStr := strings.Join(batch, " ")
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
return fmt.Errorf("add dns state for batch %d: %w", i, err)
}
s.createdKeys[key] = struct{}{}
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
err := s.addDNSState(key, domains, dnsServer, port, false)
if err != nil {
return fmt.Errorf("add dns state: %w", err)
}
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
s.createdKeys[key] = struct{}{}
return nil
}
@@ -458,6 +364,7 @@ func (s *systemConfigurator) flushDNSCache() error {
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
}
log.Info("flushed DNS cache")
return nil
}

View File

@@ -3,10 +3,7 @@
package dns
import (
"bufio"
"bytes"
"context"
"fmt"
"net/netip"
"os/exec"
"path/filepath"
@@ -52,22 +49,17 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
// Collect all created keys for cleanup verification
createdKeys := make([]string, 0, len(configurator.createdKeys))
for key := range configurator.createdKeys {
createdKeys = append(createdKeys, key)
}
defer func() {
for _, key := range createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
_ = removeTestDNSKey(localKey)
}()
for _, key := range createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
@@ -91,223 +83,13 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
func generateShortDomains(count int) []string {
domains := make([]string, 0, count)
for i := range count {
label := ""
n := i
for {
label = string(rune('a'+n%26)) + label
n = n/26 - 1
if n < 0 {
break
}
}
domains = append(domains, label+".com")
}
return domains
}
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
func generateLongDomains(count int) []string {
domains := make([]string, 0, count)
for i := range count {
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
}
return domains
}
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
func readDomainsFromKey(t *testing.T, key string) []string {
t.Helper()
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
out, err := cmd.Output()
require.NoError(t, err, "scutil show should succeed")
var domains []string
inArray := false
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
inArray = true
continue
}
if inArray {
if line == "}" {
break
}
// lines look like: "0 : a.com"
parts := strings.SplitN(line, " : ", 2)
if len(parts) == 2 {
domains = append(domains, parts[1])
}
}
}
require.NoError(t, scanner.Err())
return domains
}
func TestSplitDomainsIntoBatches(t *testing.T) {
tests := []struct {
name string
domains []string
expectedCount int
checkAllPresent bool
}{
{
name: "empty",
domains: nil,
expectedCount: 0,
},
{
name: "under_limit",
domains: generateShortDomains(10),
expectedCount: 1,
checkAllPresent: true,
},
{
name: "at_element_limit",
domains: generateShortDomains(50),
expectedCount: 1,
checkAllPresent: true,
},
{
name: "over_element_limit",
domains: generateShortDomains(51),
expectedCount: 2,
checkAllPresent: true,
},
{
name: "triple_element_limit",
domains: generateShortDomains(150),
expectedCount: 3,
checkAllPresent: true,
},
{
name: "long_domains_hit_byte_limit",
domains: generateLongDomains(50),
checkAllPresent: true,
},
{
name: "500_short_domains",
domains: generateShortDomains(500),
expectedCount: 10,
checkAllPresent: true,
},
{
name: "500_long_domains",
domains: generateLongDomains(500),
checkAllPresent: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
batches := splitDomainsIntoBatches(tc.domains)
if tc.expectedCount > 0 {
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
}
// Verify each batch respects limits
for i, batch := range batches {
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
"batch %d exceeds element limit", i)
totalBytes := 0
for j, d := range batch {
if j > 0 {
totalBytes++
}
totalBytes += len(d)
}
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
}
if tc.checkAllPresent {
var all []string
for _, batch := range batches {
all = append(all, batch...)
}
assert.Equal(t, tc.domains, all, "all domains should be present in order")
}
})
}
}
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
// and verifies all domains are readable across multiple scutil keys.
func TestMatchDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
testCases := []struct {
name string
count int
generator func(int) []string
}{
{"short_10", 10, generateShortDomains},
{"short_50", 50, generateShortDomains},
{"short_100", 100, generateShortDomains},
{"short_200", 200, generateShortDomains},
{"short_500", 500, generateShortDomains},
{"long_10", 10, generateLongDomains},
{"long_50", 50, generateLongDomains},
{"long_100", 100, generateLongDomains},
{"long_200", 200, generateLongDomains},
{"long_500", 500, generateLongDomains},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
defer func() {
for key := range configurator.createdKeys {
_ = removeTestDNSKey(key)
}
}()
domains := tc.generator(tc.count)
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
require.NoError(t, err)
batches := splitDomainsIntoBatches(domains)
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
// Read back all domains from all batched keys
var got []string
for i := range batches {
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
require.True(t, exists, "key %s should exist", key)
got = append(got, readDomainsFromKey(t, key)...)
}
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
assert.Equal(t, tc.count, len(got), "all domains should be readable")
assert.Equal(t, domains, got, "domains should match in order")
})
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
@@ -376,15 +158,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for key := range configurator.createdKeys {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
// Also clean up old-format keys and local key in case they exist
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
}
return configurator, sm, cleanup

View File

@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
}
}
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
return ruleIndex, nil
}

Some files were not shown because too many files have changed in this diff Show More