Compare commits

..

3 Commits

Author SHA1 Message Date
pascal
e8156ecbb6 add cert 2026-03-04 17:30:21 +01:00
pascal
e14ddaad57 expose fileserver 2026-03-04 15:37:23 +01:00
pascal
65e627febc add static file download 2026-03-04 15:26:19 +01:00
437 changed files with 5409 additions and 45017 deletions

View File

@@ -1,14 +0,0 @@
blank_issues_enabled: true
contact_links:
- name: Community Support
url: https://forum.netbird.io/
about: Community support forum
- name: Cloud Support
url: https://docs.netbird.io/help/report-bug-issues
about: Contact us for support
- name: Client/Connection Troubleshooting
url: https://docs.netbird.io/help/troubleshooting-client
about: See our client troubleshooting guide for help addressing common issues
- name: Self-host Troubleshooting
url: https://docs.netbird.io/selfhosted/troubleshooting
about: See our self-host troubleshooting guide for help addressing common issues

View File

@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -63,15 +63,10 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- name: Generate test script
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -1,51 +0,0 @@
name: PR Title Check
on:
pull_request:
types: [opened, edited, synchronize, reopened]
jobs:
check-title:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
with:
script: |
const title = context.payload.pull_request.title;
const allowedTags = [
'management',
'client',
'signal',
'proxy',
'relay',
'misc',
'infrastructure',
'self-hosted',
'doc',
];
const pattern = /^\[([^\]]+)\]\s+.+/;
const match = title.match(pattern);
if (!match) {
core.setFailed(
`PR title must start with a tag in brackets.\n` +
`Example: [client] fix something\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
const invalid = tags.filter(t => !allowedTags.includes(t));
if (invalid.length > 0) {
core.setFailed(
`Invalid tag(s): ${invalid.join(', ')}\n` +
`Allowed tags: ${allowedTags.join(', ')}`
);
return;
}
console.log(`Valid PR title tags: [${tags.join(', ')}]`);

View File

@@ -10,7 +10,7 @@ on:
env:
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.14.3"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -169,14 +169,6 @@ jobs:
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64
@@ -194,54 +186,18 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
- name: Tag and push PR images (amd64 only)
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*amd64*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
else
echo "main sha-$(git rev-parse --short HEAD)"
fi
}
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
for tag in $(resolve_tags); do
dst="${img_name}:${tag}"
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
done
}
export -f tag_and_push resolve_tags
PR_TAG="pr-${{ github.event.pull_request.number }}"
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
tag_and_push "$SRC"
IMG_NAME="${SRC%%:*}"
DST="${IMG_NAME}:${PR_TAG}"
echo "Tagging ${SRC} -> ${DST}"
docker tag "$SRC" "$DST"
docker push "$DST"
done
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
@@ -309,14 +265,6 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Decode GPG signing key
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
env:
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
run: |
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
- name: Install LLVM-MinGW for ARM64 cross-compilation
run: |
cd /tmp
@@ -341,24 +289,6 @@ jobs:
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
dnf install -y -q rpm-sign curl >/dev/null 2>&1
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
rpm --import /tmp/rpm-pub.key
echo "=== Verifying RPM signatures ==="
for rpm_file in /dist/*.rpm; do
[ -f "$rpm_file" ] || continue
echo "--- $(basename $rpm_file) ---"
rpm -K "$rpm_file"
done
'
- name: Clean up GPG key
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
uses: actions/upload-artifact@v4
with:

View File

@@ -61,8 +61,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
exit 1
fi

View File

@@ -154,26 +154,6 @@ builds:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -186,22 +166,18 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_deb
id: netbird-deb
bindir: /usr/bin
builds:
- netbird
formats:
- deb
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
@@ -209,19 +185,16 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client.
homepage: https://netbird.io/
license: BSD-3-Clause
id: netbird_rpm
id: netbird-rpm
bindir: /usr/bin
builds:
- netbird
formats:
- rpm
scripts:
postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh"
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
@@ -903,7 +876,7 @@ brews:
uploads:
- name: debian
ids:
- netbird_deb
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -911,7 +884,7 @@ uploads:
- name: yum
ids:
- netbird_rpm
- netbird-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -61,7 +61,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird_ui_deb
id: netbird-ui-deb
package_name: netbird-ui
builds:
- netbird-ui
@@ -80,7 +80,7 @@ nfpms:
- maintainer: Netbird <dev@netbird.io>
description: Netbird client UI.
homepage: https://netbird.io/
id: netbird_ui_rpm
id: netbird-ui-rpm
package_name: netbird-ui
builds:
- netbird-ui
@@ -95,14 +95,11 @@ nfpms:
dst: /usr/share/pixmaps/netbird.png
dependencies:
- netbird
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
uploads:
- name: debian
ids:
- netbird_ui_deb
- netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -110,7 +107,7 @@ uploads:
- name: yum
ids:
- netbird_ui_rpm
- netbird-ui-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com

View File

@@ -1,7 +1,7 @@
## Contributor License Agreement
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance

View File

@@ -126,7 +126,6 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
### Community projects
- [NetBird installer script](https://github.com/physk/netbird-installer)
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).

View File

@@ -17,7 +17,8 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,7 +23,8 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
@@ -205,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
pi := PeerInfo{
p.IP,
p.FQDN,
int(p.ConnStatus),
p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi

View File

@@ -2,20 +2,11 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
// Connection status constants exported via gomobile.
const (
ConnStatusIdle = int(peer.StatusIdle)
ConnStatusConnecting = int(peer.StatusConnecting)
ConnStatusConnected = int(peer.StatusConnected)
)
// PeerInfo describe information about the peers. It designed for the UI usage
type PeerInfo struct {
IP string
FQDN string
ConnStatus int
ConnStatus string // Todo replace to enum
Routes PeerRoutes
}

View File

@@ -181,11 +181,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
@@ -199,13 +198,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("netbird down")
time.Sleep(1 * time.Second)
@@ -213,15 +209,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
}
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = false
cmd.Println("netbird up")
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("netbird up")
time.Sleep(3 * time.Second)
@@ -267,28 +261,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird down")
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
} else {
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Printf("Local file:\n%s\n", resp.GetPath())

View File

@@ -14,9 +14,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util"
)
@@ -24,24 +22,20 @@ import (
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
var (
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
exposeExternalPort uint16
exposePin string
exposePassword string
exposeUserGroups []string
exposeDomain string
exposeNamePrefix string
exposeProtocol string
)
var exposeCmd = &cobra.Command{
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: ` netbird expose --with-password safe-pass 8080
netbird expose --protocol tcp 5432
netbird expose --protocol tcp --with-external-port 5433 5432
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
RunE: exposeFn,
Use: "expose <port>",
Short: "Expose a local port via the NetBird reverse proxy",
Args: cobra.ExactArgs(1),
Example: "netbird expose --with-password safe-pass 8080",
RunE: exposeFn,
}
func init() {
@@ -50,52 +44,7 @@ func init() {
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
}
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
func isClusterProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp", "tls":
return true
default:
return false
}
}
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
// where domain display doesn't apply. TLS uses SNI so it has a domain.
func isPortBasedProtocol(protocol string) bool {
switch strings.ToLower(protocol) {
case "tcp", "udp":
return true
default:
return false
}
}
// extractPort returns the port portion of a URL like "tcp://host:12345", or
// falls back to the given default formatted as a string.
func extractPort(serviceURL string, fallback uint16) string {
u := serviceURL
if idx := strings.Index(u, "://"); idx != -1 {
u = u[idx+3:]
}
if i := strings.LastIndex(u, ":"); i != -1 {
if p := u[i+1:]; p != "" {
return p
}
}
return strconv.FormatUint(uint64(fallback), 10)
}
// resolveExternalPort returns the effective external port, defaulting to the target port.
func resolveExternalPort(targetPort uint64) uint16 {
if exposeExternalPort != 0 {
return exposeExternalPort
}
return uint16(targetPort)
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use, http/https is supported (e.g. --protocol http)")
}
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
@@ -108,15 +57,7 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
}
if !isProtocolValid(exposeProtocol) {
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
}
if isClusterProtocol(exposeProtocol) {
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
}
} else if cmd.Flags().Changed("with-external-port") {
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
@@ -135,12 +76,7 @@ func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
}
func isProtocolValid(exposeProtocol string) bool {
switch strings.ToLower(exposeProtocol) {
case "http", "https", "tcp", "udp", "tls":
return true
default:
return false
}
return strings.ToLower(exposeProtocol) == "http" || strings.ToLower(exposeProtocol) == "https"
}
func exposeFn(cmd *cobra.Command, args []string) error {
@@ -187,7 +123,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
return err
}
req := &proto.ExposeServiceRequest{
stream, err := client.ExposeService(ctx, &proto.ExposeServiceRequest{
Port: uint32(port),
Protocol: protocol,
Pin: exposePin,
@@ -195,14 +131,9 @@ func exposeFn(cmd *cobra.Command, args []string) error {
UserGroups: exposeUserGroups,
Domain: exposeDomain,
NamePrefix: exposeNamePrefix,
}
if isClusterProtocol(exposeProtocol) {
req.ListenPort = uint32(resolveExternalPort(port))
}
stream, err := client.ExposeService(ctx, req)
})
if err != nil {
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
return fmt.Errorf("expose service: %w", err)
}
if err := handleExposeReady(cmd, stream, port); err != nil {
@@ -213,60 +144,36 @@ func exposeFn(cmd *cobra.Command, args []string) error {
}
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
p, err := expose.ParseProtocolType(exposeProtocol)
if err != nil {
return 0, fmt.Errorf("invalid protocol: %w", err)
}
switch p {
case expose.ProtocolHTTP:
switch strings.ToLower(exposeProtocol) {
case "http":
return proto.ExposeProtocol_EXPOSE_HTTP, nil
case expose.ProtocolHTTPS:
case "https":
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
case expose.ProtocolTCP:
return proto.ExposeProtocol_EXPOSE_TCP, nil
case expose.ProtocolUDP:
return proto.ExposeProtocol_EXPOSE_UDP, nil
case expose.ProtocolTLS:
return proto.ExposeProtocol_EXPOSE_TLS, nil
default:
return 0, fmt.Errorf("unhandled protocol type: %d", p)
return 0, fmt.Errorf("unsupported protocol %q: only 'http' or 'https' are supported", exposeProtocol)
}
}
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
return fmt.Errorf("receive expose event: %w", err)
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
if !ok {
switch e := event.Event.(type) {
case *proto.ExposeServiceEvent_Ready:
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", e.Ready.ServiceName)
cmd.Printf(" URL: %s\n", e.Ready.ServiceUrl)
cmd.Printf(" Domain: %s\n", e.Ready.Domain)
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Port: %d\n", port)
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
return nil
default:
return fmt.Errorf("unexpected expose event: %T", event.Event)
}
printExposeReady(cmd, ready.Ready, port)
return nil
}
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
cmd.Println("Service exposed successfully!")
cmd.Printf(" Name: %s\n", r.ServiceName)
if r.ServiceUrl != "" {
cmd.Printf(" URL: %s\n", r.ServiceUrl)
}
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
cmd.Printf(" Domain: %s\n", r.Domain)
}
cmd.Printf(" Protocol: %s\n", exposeProtocol)
cmd.Printf(" Internal: %d\n", port)
if isClusterProtocol(exposeProtocol) {
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
}
if r.PortAutoAssigned && exposeExternalPort != 0 {
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
}
cmd.Println()
cmd.Println("Press Ctrl+C to stop exposing.")
}
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {

View File

@@ -41,7 +41,7 @@ func init() {
defaultServiceName = "Netbird"
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")

View File

@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
// Common setup for service control commands
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
// rootCmd env vars are already applied by PersistentPreRunE.
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(serviceCmd)
cmd.SetOut(cmd.OutOrStdout())

View File

@@ -119,10 +119,6 @@ var installCmd = &cobra.Command{
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
svcConfig, err := createServiceConfigForInstall()
if err != nil {
return err
@@ -140,10 +136,6 @@ var installCmd = &cobra.Command{
return fmt.Errorf("install service: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
cmd.Println("NetBird service has been installed")
return nil
},
@@ -195,10 +187,6 @@ This command will temporarily stop the service, update its configuration, and re
return err
}
if err := loadAndApplyServiceParams(cmd); err != nil {
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
}
wasRunning, err := isServiceRunning()
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
return fmt.Errorf("check service status: %w", err)
@@ -234,10 +222,6 @@ This command will temporarily stop the service, update its configuration, and re
return fmt.Errorf("install service with new config: %w", err)
}
if err := saveServiceParams(currentServiceParams()); err != nil {
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
}
if wasRunning {
cmd.Println("Starting NetBird service...")
if err := s.Start(); err != nil {

View File

@@ -1,201 +0,0 @@
//go:build !ios && !android
package cmd
import (
"context"
"encoding/json"
"fmt"
"maps"
"os"
"path/filepath"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/util"
)
const serviceParamsFile = "service.json"
// serviceParams holds install-time service parameters that persist across
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
type serviceParams struct {
LogLevel string `json:"log_level"`
DaemonAddr string `json:"daemon_addr"`
ManagementURL string `json:"management_url,omitempty"`
ConfigPath string `json:"config_path,omitempty"`
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
// serviceParamsPath returns the path to the service params file.
func serviceParamsPath() string {
return filepath.Join(configs.StateDir, serviceParamsFile)
}
// loadServiceParams reads saved service parameters from disk.
// Returns nil with no error if the file does not exist.
func loadServiceParams() (*serviceParams, error) {
path := serviceParamsPath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil //nolint:nilnil
}
return nil, fmt.Errorf("read service params %s: %w", path, err)
}
var params serviceParams
if err := json.Unmarshal(data, &params); err != nil {
return nil, fmt.Errorf("parse service params %s: %w", path, err)
}
return &params, nil
}
// saveServiceParams writes current service parameters to disk atomically
// with restricted permissions.
func saveServiceParams(params *serviceParams) error {
path := serviceParamsPath()
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
return fmt.Errorf("save service params: %w", err)
}
return nil
}
// currentServiceParams captures the current state of all package-level
// variables into a serviceParams struct.
func currentServiceParams() *serviceParams {
params := &serviceParams{
LogLevel: logLevel,
DaemonAddr: daemonAddr,
ManagementURL: managementURL,
ConfigPath: configPath,
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
return params
}
// loadAndApplyServiceParams loads saved params from disk and applies them
// to any flags that were not explicitly set.
func loadAndApplyServiceParams(cmd *cobra.Command) error {
params, err := loadServiceParams()
if err != nil {
return err
}
applyServiceParams(cmd, params)
return nil
}
// applyServiceParams merges saved parameters into package-level variables
// for any flag that was not explicitly set by the user (via CLI or env var).
// Flags that were Changed() are left untouched.
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
if params == nil {
return
}
// For fields with non-empty defaults (log-level, daemon-addr), keep the
// != "" guard so that an older service.json missing the field doesn't
// clobber the default with an empty string.
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
logLevel = params.LogLevel
}
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
daemonAddr = params.DaemonAddr
}
// For optional fields where empty means "use default", always apply so
// that an explicit clear (--management-url "") persists across reinstalls.
if !rootCmd.PersistentFlags().Changed("management-url") {
managementURL = params.ManagementURL
}
if !rootCmd.PersistentFlags().Changed("config") {
configPath = params.ConfigPath
}
if !rootCmd.PersistentFlags().Changed("log-file") {
logFiles = params.LogFiles
}
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
profilesDisabled = params.DisableProfiles
}
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
updateSettingsDisabled = params.DisableUpdateSettings
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict
serviceEnvVars = envMapToSlice(merged)
}
var resetParamsCmd = &cobra.Command{
Use: "reset-params",
Short: "Remove saved service install parameters",
Long: "Removes the saved service.json file so the next install uses default parameters.",
RunE: func(cmd *cobra.Command, args []string) error {
path := serviceParamsPath()
if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
cmd.Println("No saved service parameters found")
return nil
}
return fmt.Errorf("remove service params: %w", err)
}
cmd.Printf("Removed saved service parameters (%s)\n", path)
return nil
},
}
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
func envMapToSlice(m map[string]string) []string {
s := make([]string, 0, len(m))
for k, v := range m {
s = append(s, k+"="+v)
}
return s
}

View File

@@ -1,523 +0,0 @@
//go:build !ios && !android
package cmd
import (
"encoding/json"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/configs"
)
func TestServiceParamsPath(t *testing.T) {
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params := &serviceParams{
LogLevel: "debug",
DaemonAddr: "unix:///var/run/netbird.sock",
ManagementURL: "https://my.server.com",
ConfigPath: "/etc/netbird/config.json",
LogFiles: []string{"/var/log/netbird/client.log", "console"},
DisableProfiles: true,
DisableUpdateSettings: false,
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
}
err := saveServiceParams(params)
require.NoError(t, err)
// Verify the file exists and is valid JSON.
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
require.NoError(t, err)
assert.True(t, json.Valid(data))
loaded, err := loadServiceParams()
require.NoError(t, err)
require.NotNil(t, loaded)
assert.Equal(t, params.LogLevel, loaded.LogLevel)
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
assert.Equal(t, params.LogFiles, loaded.LogFiles)
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
}
func TestLoadServiceParams_FileNotExists(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
params, err := loadServiceParams()
assert.NoError(t, err)
assert.Nil(t, params)
}
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
original := configs.StateDir
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = tmpDir
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
require.NoError(t, err)
params, err := loadServiceParams()
assert.Error(t, err)
assert.Nil(t, params)
}
func TestCurrentServiceParams(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
logLevel = "trace"
daemonAddr = "tcp://127.0.0.1:9999"
managementURL = "https://mgmt.example.com"
configPath = "/tmp/test-config.json"
logFiles = []string{"/tmp/test.log"}
profilesDisabled = true
updateSettingsDisabled = true
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
params := currentServiceParams()
assert.Equal(t, "trace", params.LogLevel)
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
assert.True(t, params.DisableProfiles)
assert.True(t, params.DisableUpdateSettings)
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
}
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
origLogLevel := logLevel
origDaemonAddr := daemonAddr
origManagementURL := managementURL
origConfigPath := configPath
origLogFiles := logFiles
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() {
logLevel = origLogLevel
daemonAddr = origDaemonAddr
managementURL = origManagementURL
configPath = origConfigPath
logFiles = origLogFiles
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
serviceEnvVars = origServiceEnvVars
})
// Reset all flags to defaults.
logLevel = "info"
daemonAddr = "unix:///var/run/netbird.sock"
managementURL = ""
configPath = "/etc/netbird/config.json"
logFiles = []string{"/var/log/netbird/client.log"}
profilesDisabled = false
updateSettingsDisabled = false
serviceEnvVars = nil
// Reset Changed state on all relevant flags.
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Simulate user explicitly setting --log-level via CLI.
logLevel = "warn"
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
saved := &serviceParams{
LogLevel: "debug",
DaemonAddr: "tcp://127.0.0.1:5555",
ManagementURL: "https://saved.example.com",
ConfigPath: "/saved/config.json",
LogFiles: []string{"/saved/client.log"},
DisableProfiles: true,
DisableUpdateSettings: true,
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
// log-level was Changed, so it should keep "warn", not use saved "debug".
assert.Equal(t, "warn", logLevel)
// All other fields were not Changed, so they should use saved values.
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
assert.Equal(t, "https://saved.example.com", managementURL)
assert.Equal(t, "/saved/config.json", configPath)
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
assert.True(t, profilesDisabled)
assert.True(t, updateSettingsDisabled)
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
}
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
origProfilesDisabled := profilesDisabled
origUpdateSettingsDisabled := updateSettingsDisabled
t.Cleanup(func() {
profilesDisabled = origProfilesDisabled
updateSettingsDisabled = origUpdateSettingsDisabled
})
// Simulate current state where booleans are true (e.g. set by previous install).
profilesDisabled = true
updateSettingsDisabled = true
// Reset Changed state so flags appear unset.
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
// Saved params have both as false.
saved := &serviceParams{
DisableProfiles: false,
DisableUpdateSettings: false,
}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.False(t, profilesDisabled, "saved false should override current true")
assert.False(t, updateSettingsDisabled, "saved false should override current true")
}
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
origManagementURL := managementURL
t.Cleanup(func() { managementURL = origManagementURL })
managementURL = "https://leftover.example.com"
// Simulate saved params where management URL was explicitly cleared.
saved := &serviceParams{
LogLevel: "info",
DaemonAddr: "unix:///var/run/netbird.sock",
// ManagementURL intentionally empty: was cleared with --management-url "".
}
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
f.Changed = false
})
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
applyServiceParams(cmd, saved)
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
}
func TestApplyServiceParams_NilParams(t *testing.T) {
origLogLevel := logLevel
t.Cleanup(func() { logLevel = origLogLevel })
logLevel = "info"
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
// Should be a no-op.
applyServiceParams(cmd, nil)
assert.Equal(t, "info", logLevel)
}
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Set up a command with --service-env marked as Changed.
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
saved := &serviceParams{
ServiceEnvVars: map[string]string{
"SAVED": "val",
"OVERLAP": "saved",
},
}
applyServiceEnvParams(cmd, saved)
// Parse result for easier assertion.
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, "yes", result["EXPLICIT"])
assert.Equal(t, "val", result["SAVED"])
// Explicit wins on conflict.
assert.Equal(t, "explicit", result["OVERLAP"])
}
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
serviceEnvVars = nil
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
saved := &serviceParams{
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
}
applyServiceEnvParams(cmd, saved)
result, err := parseServiceEnvVars(serviceEnvVars)
require.NoError(t, err)
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
// Collect all JSON field names from the serviceParams struct.
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
// Collect field names referenced in currentServiceParams and applyServiceParams.
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
// applyServiceEnvParams handles ServiceEnvVars indirectly.
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
for k, v := range applyEnvFields {
applyFields[k] = v
}
for _, field := range structFields {
assert.Contains(t, currentFields, field,
"serviceParams field %q is not captured in currentServiceParams()", field)
assert.Contains(t, applyFields, field,
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
}
}
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
// it flows through newSVCConfig() EnvVars, not CLI args.
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
require.NoError(t, err)
structFields := extractStructJSONFields(t, file, "serviceParams")
require.NotEmpty(t, structFields)
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
require.NoError(t, err)
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
fieldsNotInArgs := map[string]bool{
"ServiceEnvVars": true,
}
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
// Forward: every struct field must appear in buildServiceArguments.
for _, field := range structFields {
if fieldsNotInArgs[field] {
continue
}
globalVar := fieldToGlobalVar(field)
assert.Contains(t, buildFields, globalVar,
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
}
// Reverse: every service-related global used in buildServiceArguments must
// have a corresponding serviceParams field. This catches a developer adding
// a new flag to buildServiceArguments without adding it to the struct.
globalToField := make(map[string]string, len(structFields))
for _, field := range structFields {
globalToField[fieldToGlobalVar(field)] = field
}
// Identifiers in buildServiceArguments that are not service params
// (builtins, boilerplate, loop variables).
nonParamGlobals := map[string]bool{
"args": true, "append": true, "string": true, "_": true,
"logFile": true, // range variable over logFiles
}
for ref := range buildFields {
if nonParamGlobals[ref] {
continue
}
_, inStruct := globalToField[ref]
assert.True(t, inStruct,
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
}
}
// extractStructJSONFields returns field names from a named struct type.
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
t.Helper()
var fields []string
ast.Inspect(file, func(n ast.Node) bool {
ts, ok := n.(*ast.TypeSpec)
if !ok || ts.Name.Name != structName {
return true
}
st, ok := ts.Type.(*ast.StructType)
if !ok {
return false
}
for _, f := range st.Fields.List {
if len(f.Names) > 0 {
fields = append(fields, f.Names[0].Name)
}
}
return false
})
return fields
}
// extractFuncFieldRefs returns which of the given field names appear inside the
// named function, either as selector expressions (params.FieldName) or as
// composite literal keys (&serviceParams{FieldName: ...}).
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
t.Helper()
fieldSet := make(map[string]bool, len(fields))
for _, f := range fields {
fieldSet[f] = true
}
found := make(map[string]bool)
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
ast.Inspect(fn.Body, func(n ast.Node) bool {
switch v := n.(type) {
case *ast.SelectorExpr:
if fieldSet[v.Sel.Name] {
found[v.Sel.Name] = true
}
case *ast.KeyValueExpr:
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
found[ident.Name] = true
}
}
return true
})
return found
}
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
t.Helper()
fn := findFuncDecl(file, funcName)
require.NotNil(t, fn, "function %s not found", funcName)
refs := make(map[string]bool)
ast.Inspect(fn.Body, func(n ast.Node) bool {
if ident, ok := n.(*ast.Ident); ok {
refs[ident.Name] = true
}
return true
})
return refs
}
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
for _, decl := range file.Decls {
fn, ok := decl.(*ast.FuncDecl)
if ok && fn.Name.Name == name {
return fn
}
}
return nil
}
// fieldToGlobalVar maps serviceParams field names to the package-level variable
// names used in buildServiceArguments and applyServiceParams.
func fieldToGlobalVar(field string) string {
m := map[string]string{
"LogLevel": "logLevel",
"DaemonAddr": "daemonAddr",
"ManagementURL": "managementURL",
"ConfigPath": "configPath",
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {
return v
}
// Default: lowercase first letter.
return strings.ToLower(field[:1]) + field[1:]
}
func TestEnvMapToSlice(t *testing.T) {
m := map[string]string{"A": "1", "B": "2"}
s := envMapToSlice(m)
assert.Len(t, s, 2)
assert.Contains(t, s, "A=1")
assert.Contains(t, s, "B=2")
}
func TestEnvMapToSlice_Empty(t *testing.T) {
s := envMapToSlice(map[string]string{})
assert.Empty(t, s)
}

View File

@@ -4,9 +4,7 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -15,22 +13,6 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -97,34 +79,6 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (

View File

@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (

View File

@@ -28,7 +28,6 @@ var (
ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
checkFlag string
)
var statusCmd = &cobra.Command{
@@ -50,7 +49,6 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {
@@ -58,10 +56,6 @@ func statusFunc(cmd *cobra.Command, args []string) error {
cmd.SetOut(cmd.OutOrStdout())
if checkFlag != "" {
return runHealthCheck(cmd)
}
err := parseFilters()
if err != nil {
return err
@@ -74,17 +68,15 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx, true, false)
resp, err := getStatus(ctx, false)
if err != nil {
return err
}
status := resp.GetStatus()
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired)
if needsAuth && !jsonFlag && !yamlFlag {
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+
@@ -107,17 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag,
DaemonVersion: resp.GetDaemonVersion(),
DaemonStatus: nbstatus.ParseDaemonStatus(status),
StatusFilter: statusFilter,
PrefixNamesFilter: prefixNamesFilter,
PrefixNamesFilterMap: prefixNamesFilterMap,
IPsFilter: ipsFilterMap,
ConnectionTypeFilter: connectionTypeFilter,
ProfileName: profName,
})
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:
@@ -139,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
@@ -149,7 +131,7 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
@@ -203,83 +185,6 @@ func enableDetailFlagWhenFilterFlag() {
}
}
func runHealthCheck(cmd *cobra.Command) error {
check := strings.ToLower(checkFlag)
switch check {
case "live", "ready", "startup":
default:
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
}
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
isStartup := check == "startup"
resp, err := getStatus(ctx, isStartup, false)
if err != nil {
return err
}
switch check {
case "live":
return nil
case "ready":
return checkReadiness(resp)
case "startup":
return checkStartup(resp)
default:
return nil
}
}
func checkReadiness(resp *proto.StatusResponse) error {
daemonStatus := internal.StatusType(resp.GetStatus())
switch daemonStatus {
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
return nil
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
default:
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
}
}
func checkStartup(resp *proto.StatusResponse) error {
fullStatus := resp.GetFullStatus()
if fullStatus == nil {
return fmt.Errorf("startup check: no full status available")
}
if !fullStatus.GetManagementState().GetConnected() {
return fmt.Errorf("startup check: management not connected")
}
if !fullStatus.GetSignalState().GetConnected() {
return fmt.Errorf("startup check: signal not connected")
}
var relayCount, relaysConnected int
for _, r := range fullStatus.GetRelays() {
uri := r.GetURI()
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
continue
}
relayCount++
if r.GetAvailable() {
relaysConnected++
}
}
if relayCount > 0 && relaysConnected == 0 {
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
}
return nil
}
func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil {

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))

View File

@@ -11,7 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)

View File

@@ -14,7 +14,6 @@ import (
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
@@ -22,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -33,14 +31,14 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
const (
// PeerStatusConnected indicates the peer is in connected state.
PeerStatusConnected = peer.StatusConnected
)
// PeerConnStatus is a peer's connection status.
type PeerConnStatus = peer.ConnStatus
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -83,14 +81,6 @@ type Options struct {
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the WireGuard interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
MTU *uint16
// DNSLabels defines additional DNS labels configured in the peer.
DNSLabels []string
}
// validateCredentials checks that exactly one credential type is provided
@@ -122,12 +112,6 @@ func New(opts Options) (*Client, error) {
return nil, err
}
if opts.MTU != nil {
if err := iface.ValidateMTU(*opts.MTU); err != nil {
return nil, fmt.Errorf("invalid MTU: %w", err)
}
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -156,14 +140,9 @@ func New(opts Options) (*Client, error) {
}
}
var err error
var parsedLabels domain.List
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
return nil, fmt.Errorf("invalid dns labels: %w", err)
}
t := true
var config *profilemanager.Config
var err error
input := profilemanager.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
@@ -172,8 +151,6 @@ func New(opts Options) (*Client, error) {
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
MTU: opts.MTU,
DNSLabels: parsedLabels,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)
@@ -225,7 +202,7 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
client := internal.NewConnectClient(ctx, c.config, c.recorder)
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -375,32 +352,6 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
// It returns an ExposeSession. Call Wait on the session to keep it alive.
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
mgr := engine.GetExposeManager()
if mgr == nil {
return nil, fmt.Errorf("expose manager not available")
}
resp, err := mgr.Expose(ctx, req)
if err != nil {
return nil, fmt.Errorf("expose: %w", err)
}
return &ExposeSession{
Domain: resp.Domain,
ServiceName: resp.ServiceName,
ServiceURL: resp.ServiceURL,
mgr: mgr,
}, nil
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()

View File

@@ -1,45 +0,0 @@
package embed
import (
"context"
"errors"
"github.com/netbirdio/netbird/client/internal/expose"
)
const (
// ExposeProtocolHTTP exposes the service as HTTP.
ExposeProtocolHTTP = expose.ProtocolHTTP
// ExposeProtocolHTTPS exposes the service as HTTPS.
ExposeProtocolHTTPS = expose.ProtocolHTTPS
// ExposeProtocolTCP exposes the service as TCP.
ExposeProtocolTCP = expose.ProtocolTCP
// ExposeProtocolUDP exposes the service as UDP.
ExposeProtocolUDP = expose.ProtocolUDP
// ExposeProtocolTLS exposes the service as TLS.
ExposeProtocolTLS = expose.ProtocolTLS
)
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
type ExposeRequest = expose.Request
// ExposeProtocolType represents the protocol used for exposing a service.
type ExposeProtocolType = expose.ProtocolType
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
type ExposeSession struct {
Domain string
ServiceName string
ServiceURL string
mgr *expose.Manager
}
// Wait blocks while keeping the expose session alive.
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
func (s *ExposeSession) Wait(ctx context.Context) error {
if s == nil || s.mgr == nil {
return errors.New("expose session is not initialized")
}
return s.mgr.KeepAlive(ctx, s.Domain)
}

View File

@@ -23,10 +23,9 @@ type Manager struct {
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
rawSupported bool
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -85,7 +84,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.initNoTrackChain(); err != nil {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
return fmt.Errorf("init notrack chain: %w", err)
}
// persist early to ensure cleanup of chains
@@ -286,22 +285,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"
@@ -335,10 +318,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if !m.rawSupported {
return fmt.Errorf("raw table not available")
}
wgPortStr := fmt.Sprintf("%d", wgPort)
proxyPortStr := fmt.Sprintf("%d", proxyPort)
@@ -396,16 +375,12 @@ func (m *Manager) initNoTrackChain() error {
return fmt.Errorf("add prerouting jump rule: %w", err)
}
m.rawSupported = true
return nil
}
func (m *Manager) cleanupNoTrackChain() error {
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
if err != nil {
if !m.rawSupported {
return nil
}
return fmt.Errorf("check chain exists: %w", err)
}
if !exists {
@@ -426,7 +401,6 @@ func (m *Manager) cleanupNoTrackChain() error {
return fmt.Errorf("clear and delete chain: %w", err)
}
m.rawSupported = false
return nil
}

View File

@@ -36,7 +36,6 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
@@ -44,7 +43,6 @@ const (
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
@@ -389,14 +387,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
chain string
table string
@@ -406,7 +396,6 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -981,81 +970,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -169,14 +169,6 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error

View File

@@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
if err := m.initNoTrackChains(workTable); err != nil {
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
return fmt.Errorf("init notrack chains: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
@@ -346,22 +346,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"

View File

@@ -36,7 +36,6 @@ const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
@@ -1854,130 +1853,6 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -140,17 +140,6 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -605,8 +594,6 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -726,9 +713,6 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -911,21 +895,38 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
m.mutex.RLock()
defer m.mutex.RUnlock()
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1277,6 +1278,12 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1335,30 +1342,65 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -187,52 +186,81 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -502,12 +530,39 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -537,7 +592,8 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
manager.SetUDPPacketHook(
hookID := manager.AddUDPPacketHook(
false,
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -545,6 +601,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -144,8 +144,6 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -421,7 +421,6 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -467,22 +466,6 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,7 +18,9 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
drop bool
udpHook func([]byte) bool
}
// ID returns the rule id

View File

@@ -399,17 +399,21 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,

View File

@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
// CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
@@ -46,7 +46,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
opts := []grpc.DialOption{
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
@@ -54,10 +56,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
}
opts = append(opts, extraOpts...)
conn, err := grpc.DialContext(connCtx, addr, opts...)
)
if err != nil {
return nil, fmt.Errorf("dial context: %w", err)
}

View File

@@ -15,17 +15,14 @@ type PacketFilter interface {
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one UDP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one TCP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
}
// FilteredDevice to override Read or Write of packets

View File

@@ -34,28 +34,18 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// SetUDPPacketHook mocks base method.
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
}
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
}
// SetTCPPacketHook mocks base method.
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
}
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
@@ -85,3 +75,17 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}

View File

@@ -0,0 +1,87 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
err := a.doMgmLogin(client, ctx, pubSSHKey)
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isRegistrationNeeded(err) {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
@@ -201,7 +201,13 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
protoFlow, err := client.GetPKCEAuthorizationFlow()
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -215,7 +221,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
@@ -240,7 +246,13 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
protoFlow, err := client.GetDeviceAuthorizationFlow()
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -254,7 +266,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
@@ -280,16 +292,28 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return err
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -298,7 +322,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err

View File

@@ -23,13 +23,12 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
@@ -44,19 +43,14 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
clientMetrics *metrics.ClientMetrics
updateManager *updater.Manager
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -65,24 +59,19 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
}
}
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
c.updateManager = um
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
@@ -111,7 +100,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -121,7 +109,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
@@ -144,34 +131,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
}()
// Stop metrics push on exit
defer func() {
if c.clientMetrics != nil {
c.clientMetrics.StopPush()
}
}()
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
nbnet.Init()
// Initialize metrics once at startup (always active for debug bundles)
if c.clientMetrics == nil {
agentInfo := metrics.AgentInfo{
DeploymentType: metrics.DeploymentTypeUnknown,
Version: version.NetbirdVersion(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
c.clientMetrics = metrics.NewClientMetrics(agentInfo)
log.Debugf("initialized client metrics")
// Start metrics push if enabled (uses daemon context, persists across engine restarts)
if metrics.IsMetricsPushEnabled() {
c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv())
}
}
backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second,
RandomizationFactor: 1,
@@ -224,13 +187,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
if c.updateManager != nil {
c.updateManager.CheckUpdateSuccess(c.ctx)
}
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
@@ -258,16 +222,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
// Update metrics with actual deployment type after connection
deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL())
agentInfo := metrics.AgentInfo{
DeploymentType: deploymentType,
Version: version.NetbirdVersion(),
OS: runtime.GOOS,
Arch: runtime.GOARCH,
}
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
if err = mgmClient.Close(); err != nil {
@@ -276,10 +230,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
loginStarted := time.Now()
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
@@ -288,7 +240,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
return wrapErr(err)
}
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
c.statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{
@@ -357,16 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
SignalClient: signalClient,
MgmClient: mgmClient,
RelayManager: relayManager,
StatusRecorder: c.statusRecorder,
Checks: checks,
StateManager: stateManager,
UpdateManager: c.updateManager,
ClientMetrics: c.clientMetrics,
}, mobileDependency)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -376,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
@@ -619,6 +570,12 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -637,7 +594,12 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -1,73 +0,0 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
// It returns an empty interface list, which means ICE P2P candidates won't be
// discovered — connections will fall back to relay. Applications that need P2P
// should provide a real implementation via runOnAndroidEmbed that uses
// Android's ConnectivityManager to enumerate network interfaces.
type noopIFaceDiscover struct{}
func (noopIFaceDiscover) IFaces() (string, error) {
// Return empty JSON array — no local interfaces advertised for ICE.
// This is intentional: without Android's ConnectivityManager, we cannot
// reliably enumerate interfaces (netlink is restricted on Android 11+).
// Relay connections still work; only P2P hole-punching is disabled.
return "[]", nil
}
// noopNetworkChangeListener is a stub for embed.Client on Android.
// Network change events are ignored since the embed client manages its own
// reconnection logic via the engine's built-in retry mechanism.
type noopNetworkChangeListener struct{}
func (noopNetworkChangeListener) OnNetworkChanged(string) {
// No-op: embed.Client relies on the engine's internal reconnection
// logic rather than OS-level network change notifications.
}
func (noopNetworkChangeListener) SetInterfaceIP(string) {
// No-op: in netstack mode, the overlay IP is managed by the userspace
// network stack, not by OS-level interface configuration.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.
type noopDnsReadyListener struct{}
func (noopDnsReadyListener) OnReady() {
// No-op: embed.Client does not need DNS readiness notifications.
// System DNS is disabled in netstack mode.
}
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -1,32 +0,0 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -25,13 +25,13 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const readmeContent = `Netbird debug bundle
@@ -53,8 +53,6 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile.
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information.
@@ -221,11 +219,6 @@ const (
darwinStdoutLogPath = "/var/log/netbird.err.log"
)
// MetricsExporter is an interface for exporting metrics
type MetricsExporter interface {
Export(w io.Writer) error
}
type BundleGenerator struct {
anonymizer *anonymize.Anonymizer
@@ -236,7 +229,6 @@ type BundleGenerator struct {
logPath string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
anonymize bool
includeSystemInfo bool
@@ -258,7 +250,6 @@ type GeneratorDependencies struct {
LogPath string
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
}
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
@@ -277,7 +268,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
logPath: deps.LogPath,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
anonymize: cfg.Anonymize,
includeSystemInfo: cfg.IncludeSystemInfo,
@@ -361,14 +351,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addServiceParams(); err != nil {
log.Errorf("failed to add service params to debug bundle: %v", err)
}
if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err)
}
if err := g.addWgShow(); err != nil {
log.Errorf("failed to add wg show output: %v", err)
}
@@ -436,10 +418,7 @@ func (g *BundleGenerator) addStatus() error {
fullStatus := g.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
Anonymize: g.anonymize,
ProfileName: profName,
})
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
statusOutput := overview.FullDetailSummary()
statusReader := strings.NewReader(statusOutput)
@@ -494,90 +473,6 @@ func (g *BundleGenerator) addConfig() error {
return nil
}
const (
serviceParamsFile = "service.json"
serviceParamsBundle = "service_params.json"
maskedValue = "***"
envVarPrefix = "NB_"
jsonKeyManagementURL = "management_url"
jsonKeyServiceEnv = "service_env_vars"
)
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
func (g *BundleGenerator) addServiceParams() error {
path := filepath.Join(configs.StateDir, serviceParamsFile)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("read service params: %w", err)
}
var params map[string]any
if err := json.Unmarshal(data, &params); err != nil {
return fmt.Errorf("parse service params: %w", err)
}
if g.anonymize {
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
}
}
g.sanitizeServiceEnvVars(params)
sanitizedData, err := json.MarshalIndent(params, "", " ")
if err != nil {
return fmt.Errorf("marshal sanitized service params: %w", err)
}
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
return fmt.Errorf("add service params to zip: %w", err)
}
return nil
}
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
if !ok {
return
}
sanitized := make(map[string]any, len(envVars))
for k, v := range envVars {
val, _ := v.(string)
switch {
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
sanitized[k] = maskedValue
case g.anonymize:
sanitized[k] = g.anonymizer.AnonymizeString(val)
default:
sanitized[k] = val
}
}
params[jsonKeyServiceEnv] = sanitized
}
// isSensitiveEnvVar returns true for env var names that may contain secrets.
func isSensitiveEnvVar(key string) bool {
lower := strings.ToLower(key)
for _, s := range sensitiveEnvSubstrings {
if strings.Contains(lower, s) {
return true
}
}
return false
}
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")
@@ -849,30 +744,6 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
return nil
}
func (g *BundleGenerator) addMetrics() error {
if g.clientMetrics == nil {
log.Debugf("skipping metrics in debug bundle: no metrics collector")
return nil
}
var buf bytes.Buffer
if err := g.clientMetrics.Export(&buf); err != nil {
return fmt.Errorf("export metrics: %w", err)
}
if buf.Len() == 0 {
log.Debugf("skipping metrics.txt in debug bundle: no metrics data")
return nil
}
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
return fmt.Errorf("add metrics file to zip: %w", err)
}
log.Debugf("added metrics to debug bundle")
return nil
}
func (g *BundleGenerator) addLogfile() error {
if g.logPath == "" {
log.Debugf("skipping empty log file in debug bundle")

View File

@@ -1,12 +1,8 @@
package debug
import (
"archive/zip"
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"strings"
"testing"
@@ -14,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -425,226 +420,6 @@ func TestAnonymizeNetworkMap(t *testing.T) {
}
}
func TestIsSensitiveEnvVar(t *testing.T) {
tests := []struct {
key string
sensitive bool
}{
{"NB_SETUP_KEY", true},
{"NB_API_TOKEN", true},
{"NB_CLIENT_SECRET", true},
{"NB_PASSWORD", true},
{"NB_CREDENTIAL", true},
{"NB_LOG_LEVEL", false},
{"NB_MANAGEMENT_URL", false},
{"NB_HOSTNAME", false},
{"HOME", false},
{"PATH", false},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
})
}
}
func TestSanitizeServiceEnvVars(t *testing.T) {
tests := []struct {
name string
anonymize bool
input map[string]any
check func(t *testing.T, params map[string]any)
}{
{
name: "no env vars key",
anonymize: false,
input: map[string]any{"management_url": "https://mgmt.example.com"},
check: func(t *testing.T, params map[string]any) {
t.Helper()
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
_, ok := params[jsonKeyServiceEnv]
assert.False(t, ok, "service_env_vars should not be added")
},
},
{
name: "non-NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "sensitive NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "safe NB vars anonymized when anonymize is true",
anonymize: true,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
"NB_LOG_LEVEL": "debug",
"NB_SETUP_KEY": "secret",
"SOME_OTHER": "val",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
// Safe NB_ values should be anonymized (not the original, not masked)
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
logVal := env["NB_LOG_LEVEL"].(string)
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
// Sensitive and non-NB_ still masked
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
assert.Equal(t, maskedValue, env["SOME_OTHER"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
g := &BundleGenerator{
anonymize: tt.anonymize,
anonymizer: anonymizer,
}
g.sanitizeServiceEnvVars(tt.input)
tt.check(t, tt.input)
})
}
}
func TestAddServiceParams(t *testing.T) {
t.Run("missing service.json returns nil", func(t *testing.T) {
g := &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
}
origStateDir := configs.StateDir
configs.StateDir = t.TempDir()
t.Cleanup(func() { configs.StateDir = origStateDir })
err := g.addServiceParams()
assert.NoError(t, err)
})
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
jsonKeyServiceEnv: map[string]any{
"NB_LOG_LEVEL": "trace",
},
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: true,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
require.Len(t, zr.File, 1)
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
mgmt := result[jsonKeyManagementURL].(string)
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
assert.NotEmpty(t, mgmt)
env := result[jsonKeyServiceEnv].(map[string]any)
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
})
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: false,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
})
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{

View File

@@ -73,9 +73,6 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
return nil
}
w.response = m
if m.MsgHdr.Truncated {
w.SetMeta("truncated", "true")
}
return w.ResponseWriter.WriteMsg(m)
}
@@ -198,14 +195,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
requestID := resutil.GenerateRequestID()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
@@ -268,9 +261,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
cw.response.Len(), meta, time.Since(startTime))
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {

View File

@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
// TestLocalResolver_Stop tests cleanup on GracefullyStop
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})

View File

@@ -85,16 +85,6 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}
// SetFirewall mock implementation of SetFirewall from Server interface
func (m *MockServer) SetFirewall(Firewall) {
// Mock implementation - no-op
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op

View File

@@ -104,23 +104,3 @@ func (r *responseWriter) TsigTimersOnly(bool) {
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
var srcIP net.IP
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
srcIP = ipv4.(*layers.IPv4).SrcIP
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
srcIP = ipv6.(*layers.IPv6).SrcIP
}
var srcPort int
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
srcPort = int(udp.(*layers.UDP).SrcPort)
}
if srcIP == nil {
return nil
}
return &net.UDPAddr{IP: srcIP, Port: srcPort}
}

View File

@@ -57,8 +57,6 @@ type Server interface {
ProbeAvailability()
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
SetFirewall(Firewall)
}
type nsGroupsByDomain struct {
@@ -106,17 +104,12 @@ type DefaultServer struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
routeMatch func(netip.Addr) bool
probeMu sync.Mutex
probeCancel context.CancelFunc
probeWg sync.WaitGroup
}
type handlerWithStop interface {
dns.Handler
Stop()
ProbeAvailability(context.Context)
ProbeAvailability()
ID() types.HandlerID
}
@@ -152,7 +145,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
dnsService = newServiceViaListener(config.WgInterface, addrPort)
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
@@ -187,16 +180,11 @@ func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
hostsDnsList []netip.AddrPort,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
return ds
}
@@ -237,14 +225,6 @@ func newDefaultServer(
return defaultServer
}
// SetRouteChecker sets the function used by upstream resolvers to determine
// whether an IP is routed through the tunnel.
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
s.mux.Lock()
defer s.mux.Unlock()
s.routeMatch = f
}
// RegisterHandler registers a handler for the given domains with the given priority.
// Any previously registered handler for the same domain and priority will be replaced.
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
@@ -380,26 +360,9 @@ func (s *DefaultServer) DnsIP() netip.Addr {
return s.service.RuntimeIP()
}
// SetFirewall sets the firewall used for DNS port DNAT rules.
// This must be called before Initialize when using the listener-based service,
// because the firewall is typically not available at construction time.
func (s *DefaultServer) SetFirewall(fw Firewall) {
if svc, ok := s.service.(*serviceViaListener); ok {
svc.listenerFlagLock.Lock()
svc.firewall = fw
svc.listenerFlagLock.Unlock()
}
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
if s.probeCancel != nil {
s.probeCancel()
}
s.ctxCancel()
s.probeMu.Unlock()
s.probeWg.Wait()
s.shutdownWg.Wait()
s.mux.Lock()
@@ -412,12 +375,8 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains)
}
func (s *DefaultServer) disableDNS() (retErr error) {
defer func() {
if err := s.service.Stop(); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
}
}()
func (s *DefaultServer) disableDNS() error {
defer s.service.Stop()
if s.isUsingNoopHostManager() {
return nil
@@ -520,8 +479,7 @@ func (s *DefaultServer) SearchDomains() []string {
}
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds.
// If a previous probe is still running, it will be cancelled before starting a new one.
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
@@ -534,52 +492,15 @@ func (s *DefaultServer) ProbeAvailability() {
}
}
s.probeMu.Lock()
// don't start probes on a stopped server
if s.ctx.Err() != nil {
s.probeMu.Unlock()
return
}
// cancel any running probe
if s.probeCancel != nil {
s.probeCancel()
s.probeCancel = nil
}
// wait for the previous probe goroutines to finish while holding
// the mutex so no other caller can start a new probe concurrently
s.probeWg.Wait()
// start a new probe
probeCtx, probeCancel := context.WithCancel(s.ctx)
s.probeCancel = probeCancel
s.probeWg.Add(1)
defer s.probeWg.Done()
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
s.mux.Lock()
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
for _, mux := range s.dnsMuxMap {
handlers = append(handlers, mux.handler)
}
s.mux.Unlock()
var wg sync.WaitGroup
for _, handler := range handlers {
for _, mux := range s.dnsMuxMap {
wg.Add(1)
go func(h handlerWithStop) {
go func(mux handlerWithStop) {
defer wg.Done()
h.ProbeAvailability(probeCtx)
}(handler)
mux.ProbeAvailability()
}(mux.handler)
}
s.probeMu.Unlock()
wg.Wait()
probeCancel()
}
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
@@ -774,7 +695,6 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
return
}
handler.routeMatch = s.routeMatch
for _, ns := range originalNameservers {
if ns == config.ServerIP {
@@ -884,7 +804,6 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
if err != nil {
return nil, fmt.Errorf("create upstream resolver: %v", err)
}
handler.routeMatch = s.routeMatch
for _, ns := range nsGroup.NameServers {
if ns.NSType != nbdns.UDPNameServerType {
@@ -1069,7 +988,6 @@ func (s *DefaultServer) addHostRootZone() {
log.Errorf("unable to create a new upstream resolver, error: %v", err)
return
}
handler.routeMatch = s.routeMatch
handler.upstreamServers = maps.Keys(hostDNSServers)
handler.deactivate = func(error) {}

View File

@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
@@ -1065,13 +1065,13 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}

View File

@@ -4,25 +4,15 @@ import (
"net/netip"
"github.com/miekg/dns"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
DefaultPort = 53
)
// Firewall provides DNAT capabilities for DNS port redirection.
// This is used when the DNS server cannot bind port 53 directly
// and needs firewall rules to redirect traffic.
type Firewall interface {
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
}
type service interface {
Listen() error
Stop() error
Stop()
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int

View File

@@ -6,17 +6,12 @@ import (
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
@@ -35,33 +30,25 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
tcpServer *dns.Server
listenIP netip.Addr
listenPort uint16
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
firewall Firewall
tcpDNATConfigured bool
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
firewall: fw,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
tcpServer: &dns.Server{
Net: "tcp",
Handler: mux,
},
}
return s
@@ -82,86 +69,43 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
s.server.Addr = addr
s.tcpServer.Addr = addr
log.Debugf("starting dns on %s (UDP + TCP)", addr)
s.listenerIsRunning = true
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
if err := s.server.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
}
s.setListenerStatus(true)
defer s.setListenerStatus(false)
s.listenerFlagLock.Lock()
unexpected := s.listenerIsRunning
s.listenerIsRunning = false
s.listenerFlagLock.Unlock()
if unexpected {
if err := s.tcpServer.Shutdown(); err != nil {
log.Debugf("failed to shutdown DNS TCP server: %v", err)
}
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
}
}()
go func() {
if err := s.tcpServer.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
}
}()
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
// a DNAT rule because eBPF only handles UDP.
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
} else {
s.tcpDNATConfigured = true
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
}
}
return nil
}
func (s *serviceViaListener) Stop() error {
func (s *serviceViaListener) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return nil
return
}
s.listenerIsRunning = false
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var merr *multierror.Error
if err := s.server.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
}
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
}
if s.tcpDNATConfigured && s.firewall != nil {
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
s.tcpDNATConfigured = false
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
}
if s.ebpfService != nil {
if err := s.ebpfService.FreeDNSFwd(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -188,6 +132,12 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}
// evalListenAddress figure out the listen address for the DNS server
// first check the 53 port availability on WG interface or lo, if not success
@@ -236,28 +186,18 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrPort := netip.AddrPortFrom(ip, uint16(port))
udpAddr := net.UDPAddrFromAddrPort(addrPort)
udpLn, err := net.ListenUDP("udp", udpAddr)
addrString := fmt.Sprintf("%s:%d", ip, port)
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err != nil {
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
return false
}
if err := udpLn.Close(); err != nil {
log.Debugf("close UDP probe listener: %s", err)
}
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
err = probeListener.Close()
if err != nil {
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
return false
log.Errorf("got an error closing the probe listener, error: %s", err)
}
if err := tcpLn.Close(); err != nil {
log.Debugf("close TCP probe listener: %s", err)
}
return true
}

View File

@@ -1,86 +0,0 @@
package dns
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("192.0.2.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// Create a service using a custom address to avoid needing root
svc := newServiceViaListener(nil, nil, nil)
svc.dnsMux.Handle(".", handler)
// Bind both transports up front to avoid TOCTOU races.
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Skip("cannot bind to 127.0.0.153, skipping")
}
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
udpConn.Close()
t.Skip("cannot bind TCP on same port, skipping")
}
addr := fmt.Sprintf("%s:%d", customIP, port)
svc.server.PacketConn = udpConn
svc.tcpServer.Listener = tcpLn
svc.listenIP = customIP
svc.listenPort = port
go func() {
if err := svc.server.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := svc.tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
svc.listenerIsRunning = true
defer func() {
require.NoError(t, svc.Stop())
}()
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
// Test UDP query
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
udpResp, _, err := udpClient.Exchange(q, addr)
require.NoError(t, err, "UDP query should succeed")
require.NotNil(t, udpResp)
require.NotEmpty(t, udpResp.Answer)
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
// Test TCP query
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
tcpResp, _, err := tcpClient.Exchange(q, addr)
require.NoError(t, err, "TCP query should succeed")
require.NotNil(t, tcpResp)
require.NotEmpty(t, tcpResp.Answer)
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
}

View File

@@ -1,7 +1,6 @@
package dns
import (
"errors"
"fmt"
"net/netip"
"sync"
@@ -11,7 +10,6 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -20,8 +18,7 @@ type ServiceViaMemory struct {
dnsMux *dns.ServeMux
runtimeIP netip.Addr
runtimePort int
tcpDNS *tcpDNSServer
tcpHookSet bool
udpFilterHookID string
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
@@ -31,13 +28,14 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
return &ServiceViaMemory{
s := &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
return s
}
func (s *ServiceViaMemory) Listen() error {
@@ -48,8 +46,10 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
if err := s.filterDNSTraffic(); err != nil {
return fmt.Errorf("filter dns traffic: %w", err)
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return fmt.Errorf("filter dns traffice: %w", err)
}
s.listenerIsRunning = true
@@ -57,29 +57,19 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
func (s *ServiceViaMemory) Stop() error {
func (s *ServiceViaMemory) Stop() {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return nil
return
}
filter := s.wgInterface.GetFilter()
if filter != nil {
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
if s.tcpHookSet {
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
}
s.listenerIsRunning = false
return nil
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
@@ -98,18 +88,10 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() error {
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
filter := s.wgInterface.GetFilter()
if filter == nil {
return errors.New("DNS filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
}
firstLayerDecoder := layers.LayerTypeIPv4
@@ -118,16 +100,12 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
if udpLayer == nil {
return true
}
udp, ok := udpLayer.(*layers.UDP)
if !ok {
return true
}
udp := udpLayer.(*layers.UDP)
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
@@ -135,30 +113,13 @@ func (s *ServiceViaMemory) filterDNSTraffic() error {
return true
}
dev := s.wgInterface.GetDevice()
if dev == nil {
return true
}
writer := &responseWriter{
remote: remoteAddrFromPacket(packet),
writer := responseWriter{
packet: packet,
device: dev.Device,
device: s.wgInterface.GetDevice().Device,
}
go s.dnsMux.ServeDNS(writer, msg)
go s.dnsMux.ServeDNS(&writer, msg)
return true
}
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.InjectPacket(packetData)
return true
}
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
s.tcpHookSet = true
}
return nil
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
}

View File

@@ -1,444 +0,0 @@
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
dnsTCPReceiveWindow = 8192
dnsTCPMaxInFlight = 16
dnsTCPIdleTimeout = 30 * time.Second
dnsTCPReadTimeout = 5 * time.Second
)
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
// It is started lazily when a truncated DNS response is detected and shuts down
// after a period of inactivity to conserve resources.
type tcpDNSServer struct {
mu sync.Mutex
s *stack.Stack
ep *dnsEndpoint
mux *dns.ServeMux
tunDev tun.Device
ip netip.Addr
port uint16
mtu uint16
running bool
closed bool
timerID uint64
timer *time.Timer
}
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
return &tcpDNSServer{
mux: mux,
tunDev: tunDev,
ip: ip,
port: port,
mtu: mtu,
}
}
// InjectPacket ensures the stack is running and delivers a raw IP packet into
// the gvisor stack for TCP processing. Combining both operations under a single
// lock prevents a race where the idle timer could stop the stack between
// start and delivery.
func (t *tcpDNSServer) InjectPacket(payload []byte) {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return
}
if !t.running {
if err := t.startLocked(); err != nil {
log.Errorf("failed to start TCP DNS stack: %v", err)
return
}
t.running = true
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
}
t.resetTimerLocked()
ep := t.ep
if ep == nil || ep.dispatcher == nil {
return
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
// Stop tears down the gvisor stack and releases resources permanently.
// After Stop, InjectPacket becomes a no-op.
func (t *tcpDNSServer) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
t.closed = true
}
func (t *tcpDNSServer) startLocked() error {
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
HandleLocal: false,
})
nicID := tcpip.NICID(1)
ep := &dnsEndpoint{
tunDev: t.tunDev,
}
ep.mtu.Store(uint32(t.mtu))
if err := s.CreateNIC(nicID, ep); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
PrefixLen: 32,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("add protocol address: %s", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set spoofing: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create default subnet: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
})
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
t.handleTCPDNS(r)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
t.s = s
t.ep = ep
return nil
}
func (t *tcpDNSServer) stopLocked() {
if !t.running {
return
}
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
if t.s != nil {
t.s.Close()
t.s.Wait()
t.s = nil
}
t.ep = nil
t.running = false
log.Debugf("TCP DNS stack stopped")
}
func (t *tcpDNSServer) resetTimerLocked() {
if t.timer != nil {
t.timer.Stop()
}
t.timerID++
id := t.timerID
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
t.mu.Lock()
defer t.mu.Unlock()
// Only stop if this timer is still the active one.
// A racing InjectPacket may have replaced it.
if t.timerID != id {
return
}
t.stopLocked()
})
}
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
id := r.ID()
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
defer func() {
if err := conn.Close(); err != nil {
log.Tracef("TCP DNS: close conn: %v", err)
}
}()
// Reset idle timer on activity
t.mu.Lock()
t.resetTimerLocked()
t.mu.Unlock()
localAddr := &net.TCPAddr{
IP: id.LocalAddress.AsSlice(),
Port: int(id.LocalPort),
}
remoteAddr := &net.TCPAddr{
IP: id.RemoteAddress.AsSlice(),
Port: int(id.RemotePort),
}
for {
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
break
}
msg, err := readTCPDNSMessage(conn)
if err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
}
break
}
writer := &tcpResponseWriter{
conn: conn,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
t.mux.ServeDNS(writer, msg)
}
}
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
type dnsEndpoint struct {
dispatcher stack.NetworkDispatcher
tunDev tun.Device
mtu atomic.Uint32
}
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
func (e *dnsEndpoint) Wait() { /* no async work */ }
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
const tunPacketOffset = 40
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := data.AsSlice()
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
buf = append(buf, raw...)
data.Release()
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
continue
}
written++
}
return written, nil
}
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
type tcpResponseWriter struct {
conn *gonet.TCPConn
localAddr net.Addr
remoteAddr net.Addr
}
func (w *tcpResponseWriter) LocalAddr() net.Addr {
return w.localAddr
}
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
}
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
data, err := msg.Pack()
if err != nil {
return fmt.Errorf("pack: %w", err)
}
// DNS TCP: 2-byte length prefix + message
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err = w.conn.Write(buf); err != nil {
return err
}
return nil
}
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err := w.conn.Write(buf); err != nil {
return 0, err
}
return len(data), nil
}
func (w *tcpResponseWriter) Close() error {
return w.conn.Close()
}
func (w *tcpResponseWriter) TsigStatus() error { return nil }
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
// DNS over TCP uses a 2-byte length prefix
lenBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, fmt.Errorf("read length: %w", err)
}
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
if msgLen == 0 || msgLen > 65535 {
return nil, fmt.Errorf("invalid message length: %d", msgLen)
}
msgBuf := make([]byte, msgLen)
if _, err := io.ReadFull(conn, msgBuf); err != nil {
return nil, fmt.Errorf("read message: %w", err)
}
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
return nil, fmt.Errorf("unpack: %w", err)
}
return msg, nil
}
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
// Supports both IPv4 and IPv6.
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
if len(pkt) == 0 {
return netip.AddrPort{}
}
srcIP, transportOffset := srcIPFromPacket(pkt)
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
return netip.AddrPort{}
}
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
}
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
switch header.IPVersion(pkt) {
case 4:
return srcIPv4(pkt)
case 6:
return srcIPv6(pkt)
default:
return netip.Addr{}, 0
}
}
func srcIPv4(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv4MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv4(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, int(hdr.HeaderLength())
}
func srcIPv6(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv6MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv6(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, header.IPv6MinimumSize
}

View File

@@ -41,61 +41,10 @@ const (
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
const (
protoUDP = "udp"
protoTCP = "tcp"
)
type dnsProtocolKey struct{}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
}
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
func dnsProtocolFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
return v
}
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
func setUpstreamProtocol(ctx context.Context, protocol string) {
if ctx == nil {
return
}
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
r.protocol = protocol
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
@@ -116,12 +65,10 @@ type upstreamResolverBase struct {
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
@@ -168,11 +115,6 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
}
// ServeDNS handles a DNS request
@@ -189,16 +131,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// Propagate inbound protocol so upstream exchange can use TCP directly
// when the request came in over TCP.
ctx := u.ctx
if addr := w.RemoteAddr(); addr != nil {
network := addr.Network()
ctx = contextWithDNSProtocol(ctx, network)
resutil.SetMeta(w, "protocol", network)
}
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
ok, failures := u.tryUpstreamServers(w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
@@ -213,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -228,7 +161,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
@@ -238,17 +171,15 @@ func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.Res
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
@@ -265,7 +196,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
return nil
}
@@ -282,13 +213,10 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
}
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
@@ -332,10 +260,16 @@ func formatFailures(failures []upstreamFailure) string {
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
func (u *upstreamResolverBase) ProbeAvailability() {
u.mutex.Lock()
defer u.mutex.Unlock()
select {
case <-u.ctx.Done():
return
default:
}
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
@@ -345,39 +279,31 @@ func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
var errors *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream
wg.Add(1)
go func(upstream netip.AddrPort) {
go func() {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
err := u.testNameserver(upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
errors = multierror.Append(errors, err)
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
defer mu.Unlock()
success = true
mu.Unlock()
}(upstream)
}()
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
u.disable(errors.ErrorOrNil())
if u.statusRecorder == nil {
return
@@ -413,7 +339,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
@@ -438,9 +364,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
}
// isTimeout returns true if the given error is a network timeout error.
@@ -463,11 +387,7 @@ func (u *upstreamResolverBase) disable(err error) {
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
go u.waitUntilResponse()
}
func (u *upstreamResolverBase) upstreamServersString() string {
@@ -478,57 +398,23 @@ func (u *upstreamResolverBase) upstreamServersString() string {
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
return int(opt.UDPSize())
}
return dns.MinMsgSize
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
return rm, t, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than our read buffer.
// Note: the query could be sent out on an interface that is not ours,
// but higher MTU settings could break truncation handling.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
client.UDPSize = maxUDPPayload
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = uint16(currentMTU - (60 + 8))
var (
rm *dns.Msg
@@ -547,32 +433,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
}
if rm == nil || !rm.MsgHdr.Truncated {
setUpstreamProtocol(ctx, protoUDP)
return rm, t, nil
}
// TODO: if the upstream's truncated UDP response already contains more
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
tcpClient := *client
tcpClient.Net = protoTCP
client.Net = "tcp"
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
return rm, t, nil
}
@@ -580,46 +459,18 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
// If request came in over TCP, go straight to TCP upstream
if dnsProtocolFromContext(ctx) == protoTCP {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
return rm, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than what we can read over UDP.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, nil
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
setUpstreamProtocol(ctx, protoUDP)
return reply, nil
}
@@ -640,7 +491,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
}
}
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)

View File

@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
Timeout: timeout,
}
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -65,13 +65,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else {
upstreamIP = upstreamIP.Unmap()
}
needsPrivate := u.lNet.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil {
return nil, 0, fmt.Errorf("create private client: %s", err)
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
}
}

View File

@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
resolver.ProbeAvailability()
if !failed {
t.Errorf("expected that resolving was deactivated")
@@ -475,298 +475,3 @@ func TestFormatFailures(t *testing.T) {
})
}
}
func TestDNSProtocolContext(t *testing.T) {
t.Run("roundtrip udp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
})
t.Run("roundtrip tcp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
})
t.Run("missing returns empty", func(t *testing.T) {
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
})
}
func TestExchangeWithFallback_TCPContext(t *testing.T) {
// Start a local DNS server that responds on TCP only
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpServer := &dns.Server{
Addr: "127.0.0.1:0",
Net: "tcp",
Handler: tcpHandler,
}
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer.Listener = tcpLn
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// With TCP context, should connect directly via TCP without trying UDP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
}
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
// UDP handler returns a truncated response to trigger TCP retry.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// TCP handler returns the full answer.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.3"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{
PacketConn: udpPC,
Net: "udp",
Handler: udpHandler,
}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := udpServer.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
}()
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
assert.False(t, rm.Truncated, "TCP response should not be truncated")
}
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
// Start only a TCP server (no UDP). With TCP context it should succeed.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.2"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// TCP context: should skip UDP entirely and go directly to TCP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
// Without TCP context, trying to reach a TCP-only server via UDP should fail
ctx2 := context.Background()
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
assert.Error(t, err, "should fail when no UDP server and no TCP context")
}
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
t.Cleanup(func() { _ = udpServer.Shutdown() })
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
"upstream should see capped EDNS0, not the client's 4096")
}
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
// When the client advertises a large EDNS0 (4096) and the upstream
// truncates, the TCP response should NOT be truncated since the full
// answer fits within the client's original buffer.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
// Add enough records to exceed MTU but fit within 4096
for i := range 20 {
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
})
}
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
go func() { _ = tcpServer.ActivateAndServe() }()
t.Cleanup(func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
})
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
// Client with large buffer: should get all records without truncation
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
// Client with small buffer: should get truncated response
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r2.SetEdns0(512, false)
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
require.NoError(t, err)
require.NotNil(t, rm2)
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}

View File

@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
@@ -263,28 +263,20 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
fields := log.Fields{
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
})
f.handleDNSQuery(logger, w, query, startTime)
}

View File

@@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
@@ -46,14 +45,13 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
@@ -81,6 +79,7 @@ const (
var ErrResetConnection = fmt.Errorf("reset connection")
// EngineConfig is a config for the Engine
type EngineConfig struct {
WgPort int
WgIfaceName string
@@ -142,18 +141,6 @@ type EngineConfig struct {
LogPath string
}
// EngineServices holds the external service dependencies required by the Engine.
type EngineServices struct {
SignalClient signal.Client
MgmClient mgm.Client
RelayManager *relayClient.Manager
StatusRecorder *peer.Status
Checks []*mgmProto.Checks
StateManager *statemanager.Manager
UpdateManager *updater.Manager
ClientMetrics *metrics.ClientMetrics
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
@@ -211,10 +198,9 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
@@ -223,7 +209,7 @@ type Engine struct {
flowManager nftypes.FlowManager
// auto-update
updateManager *updater.Manager
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -233,9 +219,6 @@ type Engine struct {
probeStunTurn *relay.StunTurnProbe
// clientMetrics collects and pushes metrics
clientMetrics *metrics.ClientMetrics
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
@@ -256,32 +239,34 @@ type localIpUpdater interface {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
services EngineServices,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
relayManager: relayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -324,7 +309,7 @@ func (e *Engine) Stop() error {
}
if e.updateManager != nil {
e.updateManager.SetDownloadOnly()
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
@@ -502,17 +487,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
@@ -524,11 +498,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return err
}
// Inject firewall into DNS server now that it's available.
// The DNS server is created before the firewall because the route manager
// depends on the DNS server, and the firewall depends on the wg interface.
e.dnsServer.SetFirewall(e.firewall)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
@@ -540,13 +509,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -597,6 +559,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -824,30 +793,45 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
if e.updateManager == nil {
return
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
if autoUpdateSettings.Version == disableAutoUpdate {
log.Infof("auto-update is disabled")
e.updateManager.SetDownloadOnly()
disabled := autoUpdateSettings.Version == disableAutoUpdate
// stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
duration := time.Since(started)
log.Infof("sync finished in %s", duration)
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -858,7 +842,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
@@ -1023,11 +1007,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return errors.New("wireguard interface is not initialized")
}
// Cannot update the IP address without restarting the engine because
// the firewall, route manager, and other components cache the old address
if e.wgInterface.Address().String() != conf.Address {
log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
}
if conf.GetSshConfig() != nil {
@@ -1095,7 +1078,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)
},
@@ -1333,7 +1315,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
e.dnsServer.ProbeAvailability()
return nil
}
@@ -1550,13 +1533,11 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
MetricsRecorder: e.clientMetrics,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1713,12 +1694,6 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1822,7 +1797,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
@@ -1859,16 +1834,6 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// IsBlockInbound returns whether inbound connections are blocked.
func (e *Engine) IsBlockInbound() bool {
return e.config.BlockInbound
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -251,6 +251,9 @@ func TestEngine_SSH(t *testing.T) {
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
@@ -260,13 +263,10 @@ func TestEngine_SSH(t *testing.T) {
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -428,18 +428,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -652,18 +647,13 @@ func TestEngine_Sync(t *testing.T) {
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -822,18 +812,13 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1029,18 +1014,13 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1538,8 +1518,13 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1561,12 +1546,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}

View File

@@ -4,34 +4,27 @@ import (
"context"
"time"
log "github.com/sirupsen/logrus"
mgm "github.com/netbirdio/netbird/shared/management/client"
log "github.com/sirupsen/logrus"
)
const (
renewTimeout = 10 * time.Second
)
const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
PortAutoAssigned bool
ServiceName string
ServiceURL string
Domain string
}
// Request holds the parameters for exposing a local service via the management server.
// It is part of the embed API surface and exposed via a type alias.
type Request struct {
NamePrefix string
Domain string
Port uint16
Protocol ProtocolType
Protocol int
Pin string
Password string
UserGroups []string
ListenPort uint16
}
type ManagementClient interface {
@@ -64,8 +57,6 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
return fromClientExposeResponse(resp), nil
}
// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
// It is part of the embed API surface and exposed via a type alias.
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()

View File

@@ -86,7 +86,7 @@ func TestNewRequest(t *testing.T) {
exposeReq := NewRequest(req)
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
assert.Equal(t, "secret", exposeReq.Password, "password should match")
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")

View File

@@ -1,40 +0,0 @@
package expose
import (
"fmt"
"strings"
)
// ProtocolType represents the protocol used for exposing a service.
type ProtocolType int
const (
// ProtocolHTTP exposes the service as HTTP.
ProtocolHTTP ProtocolType = 0
// ProtocolHTTPS exposes the service as HTTPS.
ProtocolHTTPS ProtocolType = 1
// ProtocolTCP exposes the service as TCP.
ProtocolTCP ProtocolType = 2
// ProtocolUDP exposes the service as UDP.
ProtocolUDP ProtocolType = 3
// ProtocolTLS exposes the service as TLS.
ProtocolTLS ProtocolType = 4
)
// ParseProtocolType parses a protocol string into a ProtocolType.
func ParseProtocolType(s string) (ProtocolType, error) {
switch strings.ToLower(s) {
case "http":
return ProtocolHTTP, nil
case "https":
return ProtocolHTTPS, nil
case "tcp":
return ProtocolTCP, nil
case "udp":
return ProtocolUDP, nil
case "tls":
return ProtocolTLS, nil
default:
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
}
}

View File

@@ -9,13 +9,12 @@ import (
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
return &Request{
Port: uint16(req.Port),
Protocol: ProtocolType(req.Protocol),
Protocol: int(req.Protocol),
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
ListenPort: uint16(req.ListenPort),
}
}
@@ -24,19 +23,17 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
NamePrefix: req.NamePrefix,
Domain: req.Domain,
Port: req.Port,
Protocol: int(req.Protocol),
Protocol: req.Protocol,
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
ListenPort: req.ListenPort,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
PortAutoAssigned: response.PortAutoAssigned,
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
}
}

View File

@@ -1,17 +0,0 @@
package metrics
// ConnectionType represents the type of peer connection
type ConnectionType string
const (
// ConnectionTypeICE represents a direct peer-to-peer connection using ICE
ConnectionTypeICE ConnectionType = "ice"
// ConnectionTypeRelay represents a relayed connection
ConnectionTypeRelay ConnectionType = "relay"
)
// String returns the string representation of the connection type
func (c ConnectionType) String() string {
return string(c)
}

View File

@@ -1,51 +0,0 @@
package metrics
import (
"net/url"
"strings"
)
// DeploymentType represents the type of NetBird deployment
type DeploymentType int
const (
// DeploymentTypeUnknown represents an unknown or uninitialized deployment type
DeploymentTypeUnknown DeploymentType = iota
// DeploymentTypeCloud represents a cloud-hosted NetBird deployment
DeploymentTypeCloud
// DeploymentTypeSelfHosted represents a self-hosted NetBird deployment
DeploymentTypeSelfHosted
)
// String returns the string representation of the deployment type
func (d DeploymentType) String() string {
switch d {
case DeploymentTypeCloud:
return "cloud"
case DeploymentTypeSelfHosted:
return "selfhosted"
default:
return "unknown"
}
}
// DetermineDeploymentType determines if the deployment is cloud or self-hosted
// based on the management URL string
func DetermineDeploymentType(managementURL string) DeploymentType {
if managementURL == "" {
return DeploymentTypeUnknown
}
u, err := url.Parse(managementURL)
if err != nil {
return DeploymentTypeSelfHosted
}
if strings.ToLower(u.Hostname()) == "api.netbird.io" {
return DeploymentTypeCloud
}
return DeploymentTypeSelfHosted
}

View File

@@ -1,93 +0,0 @@
package metrics
import (
"net/url"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
const (
// EnvMetricsPushEnabled controls whether collected metrics are pushed to the backend.
// Metrics collection itself is always active (for debug bundles).
// Disabled by default. Set NB_METRICS_PUSH_ENABLED=true to enable push.
EnvMetricsPushEnabled = "NB_METRICS_PUSH_ENABLED"
// EnvMetricsForceSending if set to true, skips remote configuration fetch and forces metric sending
EnvMetricsForceSending = "NB_METRICS_FORCE_SENDING"
// EnvMetricsConfigURL is the environment variable to override the metrics push config ServerAddress
EnvMetricsConfigURL = "NB_METRICS_CONFIG_URL"
// EnvMetricsServerURL is the environment variable to override the metrics server address.
// When set, this takes precedence over the server_url from remote push config.
EnvMetricsServerURL = "NB_METRICS_SERVER_URL"
// EnvMetricsInterval overrides the push interval from the remote config.
// Only affects how often metrics are pushed; remote config availability
// and version range checks are still respected.
// Format: duration string like "1h", "30m", "4h"
EnvMetricsInterval = "NB_METRICS_INTERVAL"
defaultMetricsConfigURL = "https://ingest.netbird.io/config"
)
// IsMetricsPushEnabled returns true if metrics push is enabled via NB_METRICS_PUSH_ENABLED env var.
// Disabled by default. Metrics collection is always active for debug bundles.
func IsMetricsPushEnabled() bool {
enabled, _ := strconv.ParseBool(os.Getenv(EnvMetricsPushEnabled))
return enabled
}
// getMetricsInterval returns the metrics push interval from NB_METRICS_INTERVAL env var.
// Returns 0 if not set or invalid.
func getMetricsInterval() time.Duration {
intervalStr := os.Getenv(EnvMetricsInterval)
if intervalStr == "" {
return 0
}
interval, err := time.ParseDuration(intervalStr)
if err != nil {
log.Warnf("invalid metrics interval from env %q: %v", intervalStr, err)
return 0
}
if interval <= 0 {
log.Warnf("invalid metrics interval from env %q: must be positive", intervalStr)
return 0
}
return interval
}
func isForceSending() bool {
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
return force
}
// getMetricsConfigURL returns the URL to fetch push configuration from
func getMetricsConfigURL() string {
if envURL := os.Getenv(EnvMetricsConfigURL); envURL != "" {
return envURL
}
return defaultMetricsConfigURL
}
// getMetricsServerURL returns the metrics server URL from NB_METRICS_SERVER_URL env var.
// Returns nil if not set or invalid.
func getMetricsServerURL() *url.URL {
envURL := os.Getenv(EnvMetricsServerURL)
if envURL == "" {
return nil
}
parsed, err := url.ParseRequestURI(envURL)
if err != nil || parsed.Host == "" {
log.Warnf("invalid metrics server URL %q: must be an absolute HTTP(S) URL", envURL)
return nil
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
log.Warnf("invalid metrics server URL %q: unsupported scheme %q", envURL, parsed.Scheme)
return nil
}
return parsed
}

View File

@@ -1,219 +0,0 @@
package metrics
import (
"context"
"fmt"
"io"
"maps"
"slices"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxSampleAge = 5 * 24 * time.Hour // drop samples older than 5 days
maxBufferSize = 5 * 1024 * 1024 // drop oldest samples when estimated size exceeds 5 MB
// estimatedSampleSize is a rough per-sample memory estimate (measurement + tags + fields + timestamp)
estimatedSampleSize = 256
)
// influxSample is a single InfluxDB line protocol entry.
type influxSample struct {
measurement string
tags string
fields map[string]float64
timestamp time.Time
}
// influxDBMetrics collects metric events as timestamped samples.
// Each event is recorded with its exact timestamp, pushed once, then cleared.
type influxDBMetrics struct {
mu sync.Mutex
samples []influxSample
}
func newInfluxDBMetrics() metricsImplementation {
return &influxDBMetrics{}
}
func (m *influxDBMetrics) RecordConnectionStages(
_ context.Context,
agentInfo AgentInfo,
connectionPairID string,
connectionType ConnectionType,
isReconnection bool,
timestamps ConnectionStageTimestamps,
) {
var signalingReceivedToConnection, connectionToWgHandshake, totalDuration float64
if !timestamps.SignalingReceived.IsZero() && !timestamps.ConnectionReady.IsZero() {
signalingReceivedToConnection = timestamps.ConnectionReady.Sub(timestamps.SignalingReceived).Seconds()
}
if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
connectionToWgHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds()
}
if !timestamps.SignalingReceived.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.SignalingReceived).Seconds()
}
attemptType := "initial"
if isReconnection {
attemptType = "reconnection"
}
connTypeStr := connectionType.String()
tags := fmt.Sprintf("deployment_type=%s,connection_type=%s,attempt_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,connection_pair_id=%s",
agentInfo.DeploymentType.String(),
connTypeStr,
attemptType,
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
connectionPairID,
)
now := time.Now()
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_peer_connection",
tags: tags,
fields: map[string]float64{
"signaling_to_connection_seconds": signalingReceivedToConnection,
"connection_to_wg_handshake_seconds": connectionToWgHandshake,
"total_seconds": totalDuration,
},
timestamp: now,
})
m.trimLocked()
log.Tracef("peer connection metrics [%s, %s, %s]: signalingReceived→connection: %.3fs, connection→wg_handshake: %.3fs, total: %.3fs",
agentInfo.DeploymentType.String(), connTypeStr, attemptType, signalingReceivedToConnection, connectionToWgHandshake, totalDuration)
}
func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {
result = "failure"
}
tags := fmt.Sprintf("deployment_type=%s,result=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
result,
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_login",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
log.Tracef("login metrics [%s, %s]: duration=%.3fs", agentInfo.DeploymentType.String(), result, duration.Seconds())
}
// Export writes pending samples in InfluxDB line protocol format.
// Format: measurement,tag=val,tag=val field=val,field=val timestamp_ns
func (m *influxDBMetrics) Export(w io.Writer) error {
m.mu.Lock()
samples := make([]influxSample, len(m.samples))
copy(samples, m.samples)
m.mu.Unlock()
for _, s := range samples {
if _, err := fmt.Fprintf(w, "%s,%s ", s.measurement, s.tags); err != nil {
return err
}
sortedKeys := slices.Sorted(maps.Keys(s.fields))
first := true
for _, k := range sortedKeys {
if !first {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
}
if _, err := fmt.Fprintf(w, "%s=%g", k, s.fields[k]); err != nil {
return err
}
first = false
}
if _, err := fmt.Fprintf(w, " %d\n", s.timestamp.UnixNano()); err != nil {
return err
}
}
return nil
}
// Reset clears pending samples after a successful push
func (m *influxDBMetrics) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.samples = m.samples[:0]
}
// trimLocked removes samples that exceed age or size limits.
// Must be called with m.mu held.
func (m *influxDBMetrics) trimLocked() {
now := time.Now()
// drop samples older than maxSampleAge
cutoff := 0
for cutoff < len(m.samples) && now.Sub(m.samples[cutoff].timestamp) > maxSampleAge {
cutoff++
}
if cutoff > 0 {
copy(m.samples, m.samples[cutoff:])
m.samples = m.samples[:len(m.samples)-cutoff]
log.Debugf("influxdb metrics: dropped %d samples older than %s", cutoff, maxSampleAge)
}
// drop oldest samples if estimated size exceeds maxBufferSize
maxSamples := maxBufferSize / estimatedSampleSize
if len(m.samples) > maxSamples {
drop := len(m.samples) - maxSamples
copy(m.samples, m.samples[drop:])
m.samples = m.samples[:maxSamples]
log.Debugf("influxdb metrics: dropped %d oldest samples to stay under %d MB size limit", drop, maxBufferSize/(1024*1024))
}
}

View File

@@ -1,229 +0,0 @@
package metrics
import (
"bytes"
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInfluxDBMetrics_RecordAndExport(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
ts := ConnectionStageTimestamps{
SignalingReceived: time.Now().Add(-3 * time.Second),
ConnectionReady: time.Now().Add(-2 * time.Second),
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
}
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_peer_connection,")
assert.Contains(t, output, "connection_to_wg_handshake_seconds=")
assert.Contains(t, output, "signaling_to_connection_seconds=")
assert.Contains(t, output, "total_seconds=")
}
func TestInfluxDBMetrics_ExportDeterministicFieldOrder(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
ts := ConnectionStageTimestamps{
SignalingReceived: time.Now().Add(-3 * time.Second),
ConnectionReady: time.Now().Add(-2 * time.Second),
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
}
// Record multiple times and verify consistent field order
for i := 0; i < 10; i++ {
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
}
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
require.Len(t, lines, 10)
// Extract field portion from each line and verify they're all identical
var fieldSections []string
for _, line := range lines {
parts := strings.SplitN(line, " ", 3)
require.Len(t, parts, 3, "each line should have measurement, fields, timestamp")
fieldSections = append(fieldSections, parts[1])
}
for i := 1; i < len(fieldSections); i++ {
assert.Equal(t, fieldSections[0], fieldSections[i], "field order should be deterministic across samples")
}
// Fields should be alphabetically sorted
assert.True(t, strings.HasPrefix(fieldSections[0], "connection_to_wg_handshake_seconds="),
"fields should be sorted: connection_to_wg < signaling_to < total")
}
func TestInfluxDBMetrics_RecordSyncDuration(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeSelfHosted,
Version: "2.0.0",
OS: "darwin",
Arch: "arm64",
peerID: "def456",
}
m.RecordSyncDuration(context.Background(), agentInfo, 1500*time.Millisecond)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_sync,")
assert.Contains(t, output, "duration_seconds=1.5")
assert.Contains(t, output, "deployment_type=selfhosted")
}
func TestInfluxDBMetrics_Reset(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
m.RecordSyncDuration(context.Background(), agentInfo, time.Second)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
assert.NotEmpty(t, buf.String())
m.Reset()
buf.Reset()
err = m.Export(&buf)
require.NoError(t, err)
assert.Empty(t, buf.String(), "should be empty after reset")
}
func TestInfluxDBMetrics_ExportEmpty(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
assert.Empty(t, buf.String())
}
func TestInfluxDBMetrics_TrimByAge(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
m.mu.Lock()
m.samples = append(m.samples, influxSample{
measurement: "old",
tags: "t=1",
fields: map[string]float64{"v": 1},
timestamp: time.Now().Add(-maxSampleAge - time.Hour),
})
m.trimLocked()
remaining := len(m.samples)
m.mu.Unlock()
assert.Equal(t, 0, remaining, "old samples should be trimmed")
}
func TestInfluxDBMetrics_RecordLoginDuration(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeCloud,
Version: "1.0.0",
OS: "linux",
Arch: "amd64",
peerID: "abc123",
}
m.RecordLoginDuration(context.Background(), agentInfo, 2500*time.Millisecond, true)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_login,")
assert.Contains(t, output, "duration_seconds=2.5")
assert.Contains(t, output, "result=success")
}
func TestInfluxDBMetrics_RecordLoginDurationFailure(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
agentInfo := AgentInfo{
DeploymentType: DeploymentTypeSelfHosted,
Version: "1.0.0",
OS: "darwin",
Arch: "arm64",
peerID: "xyz789",
}
m.RecordLoginDuration(context.Background(), agentInfo, 5*time.Second, false)
var buf bytes.Buffer
err := m.Export(&buf)
require.NoError(t, err)
output := buf.String()
assert.Contains(t, output, "netbird_login,")
assert.Contains(t, output, "result=failure")
assert.Contains(t, output, "deployment_type=selfhosted")
}
func TestInfluxDBMetrics_TrimBySize(t *testing.T) {
m := newInfluxDBMetrics().(*influxDBMetrics)
maxSamples := maxBufferSize / estimatedSampleSize
m.mu.Lock()
for i := 0; i < maxSamples+100; i++ {
m.samples = append(m.samples, influxSample{
measurement: "test",
tags: "t=1",
fields: map[string]float64{"v": float64(i)},
timestamp: time.Now(),
})
}
m.trimLocked()
remaining := len(m.samples)
m.mu.Unlock()
assert.Equal(t, maxSamples, remaining, "should trim to max samples")
}

View File

@@ -1,16 +0,0 @@
# Copy to .env and adjust values before running docker compose
# InfluxDB admin (server-side only, never exposed to clients)
INFLUXDB_ADMIN_PASSWORD=changeme
INFLUXDB_ADMIN_TOKEN=changeme
# Grafana admin credentials
GRAFANA_ADMIN_USER=admin
GRAFANA_ADMIN_PASSWORD=changeme
# Remote config served by ingest at /config
# Set CONFIG_METRICS_SERVER_URL to the ingest server's public address to enable
CONFIG_METRICS_SERVER_URL=
CONFIG_VERSION_SINCE=0.0.0
CONFIG_VERSION_UNTIL=99.99.99
CONFIG_PERIOD_MINUTES=5

View File

@@ -1 +0,0 @@
.env

View File

@@ -1,194 +0,0 @@
# Client Metrics
Internal documentation for the NetBird client metrics system.
## Overview
Client metrics track connection performance and sync durations using InfluxDB line protocol (`influxdb.go`). Each event is pushed once then cleared.
Metrics collection is always active (for debug bundles). Push to backend is:
- Disabled by default (opt-in via `NB_METRICS_PUSH_ENABLED=true`)
- Managed at daemon layer (survives engine restarts)
## Architecture
### Layer Separation
```text
Daemon Layer (connect.go)
├─ Creates ClientMetrics instance once
├─ Starts/stops push lifecycle
└─ Updates AgentInfo on profile switch
Engine Layer (engine.go)
└─ Records metrics via ClientMetrics methods
```
### Ingest Server
Clients do not talk to InfluxDB directly. An ingest server sits between clients and InfluxDB:
```text
Client ──POST──▶ Ingest Server (:8087) ──▶ InfluxDB (internal)
├─ Validates line protocol
├─ Allowlists measurements, fields, and tags
├─ Rejects out-of-bound values
└─ Serves remote config at /config
```
- **No secret/token-based client auth** — the ingest server holds the InfluxDB token server-side. Clients must send a hashed peer ID via `X-Peer-ID` header.
- **InfluxDB is not exposed** — only accessible within the docker network
- Source: `ingest/main.go`
## Metrics Collected
### Connection Stage Timing
Measurement: `netbird_peer_connection`
| Field | Timestamps | Description |
|-------|-----------|-------------|
| `signaling_to_connection_seconds` | `SignalingReceived → ConnectionReady` | ICE/relay negotiation time after the first signal is received from the remote peer |
| `connection_to_wg_handshake_seconds` | `ConnectionReady → WgHandshakeSuccess` | WireGuard cryptographic handshake latency once the transport layer is ready |
| `total_seconds` | `SignalingReceived → WgHandshakeSuccess` | End-to-end connection time anchored at the first received signal |
Tags:
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `connection_type`: "ice" | "relay"
- `attempt_type`: "initial" | "reconnection"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** `SignalingReceived` is set when the first offer or answer arrives from the remote peer (in both initial and reconnection paths). It excludes the potentially unbounded wait for the remote peer to come online.
### Sync Duration
Measurement: `netbird_sync`
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time to process a sync message from management server |
Tags:
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Login Duration
Measurement: `netbird_login`
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time to complete the login/auth exchange with management server |
Tags:
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `result`: "success" | "failure"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
## Buffer Limits
The InfluxDB backend limits in-memory sample storage to prevent unbounded growth when pushes fail:
- **Max age:** Samples older than 5 days are dropped
- **Max size:** Estimated buffer size capped at 5 MB (~20k samples)
## Configuration
### Client Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `NB_METRICS_PUSH_ENABLED` | `false` | Enable metrics push to backend |
| `NB_METRICS_SERVER_URL` | *(from remote config)* | Ingest server URL (e.g., `https://ingest.netbird.io`) |
| `NB_METRICS_INTERVAL` | *(from remote config)* | Push interval (e.g., "1m", "30m", "4h") |
| `NB_METRICS_FORCE_SENDING` | `false` | Skip remote config, push unconditionally |
| `NB_METRICS_CONFIG_URL` | `https://ingest.netbird.io/config` | Remote push config URL |
`NB_METRICS_SERVER_URL` and `NB_METRICS_INTERVAL` override their respective values but do not bypass remote config eligibility checks (version range). Use `NB_METRICS_FORCE_SENDING=true` to skip all remote config gating.
### Ingest Server Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `INGEST_LISTEN_ADDR` | `:8087` | Listen address |
| `INFLUXDB_URL` | `http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns` | InfluxDB write endpoint |
| `INFLUXDB_TOKEN` | *(required)* | InfluxDB auth token (server-side only) |
| `CONFIG_METRICS_SERVER_URL` | *(empty — disables /config)* | `server_url` in the remote config JSON (the URL clients push metrics to) |
| `CONFIG_VERSION_SINCE` | `0.0.0` | Minimum client version to push metrics |
| `CONFIG_VERSION_UNTIL` | `99.99.99` | Maximum client version to push metrics |
| `CONFIG_PERIOD_MINUTES` | `5` | Push interval in minutes |
The ingest server serves a remote config JSON at `GET /config` when `CONFIG_METRICS_SERVER_URL` is set. Clients can use `NB_METRICS_CONFIG_URL=http://<ingest>/config` to fetch it.
### Configuration Precedence
For URL and Interval, the precedence is:
1. **Environment variable** - `NB_METRICS_SERVER_URL` / `NB_METRICS_INTERVAL`
2. **Remote config** - fetched from `NB_METRICS_CONFIG_URL`
3. **Default** - 5 minute interval, URL from remote config
## Push Behavior
1. `StartPush()` spawns background goroutine with timer
2. First push happens immediately on startup
3. Periodically: `push()``Export()` → HTTP POST to ingest server
4. On failure: log error, continue (non-blocking)
5. On success: `Reset()` clears pushed samples
6. `StopPush()` cancels context and waits for goroutine
Samples are collected with exact timestamps, pushed once, then cleared. No data is resent.
## Local Development Setup
### 1. Configure and Start Services
```bash
# From this directory (client/internal/metrics/infra)
cp .env.example .env
# Edit .env to set INFLUXDB_ADMIN_PASSWORD, INFLUXDB_ADMIN_TOKEN, and GRAFANA_ADMIN_PASSWORD
docker compose up -d
```
This starts:
- **Ingest server** on http://localhost:8087 — accepts client metrics (requires `X-Peer-ID` header, no secret/token auth)
- **InfluxDB** — internal only, not exposed to host
- **Grafana** on http://localhost:3001
### 2. Configure Client
```bash
export NB_METRICS_PUSH_ENABLED=true
export NB_METRICS_FORCE_SENDING=true
export NB_METRICS_SERVER_URL=http://localhost:8087
export NB_METRICS_INTERVAL=1m
```
### 3. Run Client
```bash
cd ../../../..
go run ./client/ up
```
### 4. View in Grafana
- **InfluxDB dashboard:** http://localhost:3001/d/netbird-influxdb-metrics
### 5. Verify Data
```bash
# Query via InfluxDB (using admin token from .env)
docker compose exec influxdb influx query \
'from(bucket: "metrics") |> range(start: -1h)' \
--org netbird
# Check ingest server health
curl http://localhost:8087/health
```

View File

@@ -1,69 +0,0 @@
version: '3.8'
services:
ingest:
container_name: ingest
build:
context: ./ingest
ports:
- "8087:8087"
environment:
- INGEST_LISTEN_ADDR=:8087
- INFLUXDB_URL=http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns
- INFLUXDB_TOKEN=${INFLUXDB_ADMIN_TOKEN:?required}
- CONFIG_METRICS_SERVER_URL=${CONFIG_METRICS_SERVER_URL:-}
- CONFIG_VERSION_SINCE=${CONFIG_VERSION_SINCE:-0.0.0}
- CONFIG_VERSION_UNTIL=${CONFIG_VERSION_UNTIL:-99.99.99}
- CONFIG_PERIOD_MINUTES=${CONFIG_PERIOD_MINUTES:-5}
depends_on:
- influxdb
restart: unless-stopped
networks:
- metrics
influxdb:
container_name: influxdb
image: influxdb:2
# No ports exposed — only accessible within the metrics network
volumes:
- influxdb-data:/var/lib/influxdb2
- ./influxdb/scripts:/docker-entrypoint-initdb.d
environment:
- DOCKER_INFLUXDB_INIT_MODE=setup
- DOCKER_INFLUXDB_INIT_USERNAME=admin
- DOCKER_INFLUXDB_INIT_PASSWORD=${INFLUXDB_ADMIN_PASSWORD:?required}
- DOCKER_INFLUXDB_INIT_ORG=netbird
- DOCKER_INFLUXDB_INIT_BUCKET=metrics
- DOCKER_INFLUXDB_INIT_RETENTION=365d
- DOCKER_INFLUXDB_INIT_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-}
restart: unless-stopped
networks:
- metrics
grafana:
container_name: grafana
image: grafana/grafana:11.6.0
ports:
- "3001:3000"
environment:
- GF_SECURITY_ADMIN_USER=${GRAFANA_ADMIN_USER:-admin}
- GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD:?required}
- GF_USERS_ALLOW_SIGN_UP=false
- GF_INSTALL_PLUGINS=
- INFLUXDB_ADMIN_TOKEN=${INFLUXDB_ADMIN_TOKEN:-}
volumes:
- grafana-data:/var/lib/grafana
- ./grafana/provisioning:/etc/grafana/provisioning
depends_on:
- influxdb
restart: unless-stopped
networks:
- metrics
volumes:
influxdb-data:
grafana-data:
networks:
metrics:
driver: bridge

View File

@@ -1,12 +0,0 @@
apiVersion: 1
providers:
- name: 'NetBird Dashboards'
orgId: 1
folder: ''
type: file
disableDeletion: false
updateIntervalSeconds: 10
allowUiUpdates: true
options:
path: /etc/grafana/provisioning/dashboards/json

View File

@@ -1,280 +0,0 @@
{
"uid": "netbird-influxdb-metrics",
"title": "NetBird Client Metrics (InfluxDB)",
"tags": ["netbird", "connections", "influxdb"],
"timezone": "browser",
"panels": [
{
"id": 5,
"title": "Sync Duration Extremes",
"type": "stat",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 0
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")",
"refId": "A"
},
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")",
"refId": "B"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0
}
},
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"colorMode": "value",
"graphMode": "none",
"textMode": "auto"
}
},
{
"id": 6,
"title": "Total Connection Time Extremes",
"type": "stat",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 0
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> min()\n |> set(key: \"_field\", value: \"Min\")",
"refId": "A"
},
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> max()\n |> set(key: \"_field\", value: \"Max\")",
"refId": "B"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0
}
},
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"colorMode": "value",
"graphMode": "none",
"textMode": "auto"
}
},
{
"id": 1,
"title": "Sync Duration",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 8
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Sync Duration\")",
"refId": "A"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0,
"custom": {
"drawStyle": "points",
"pointSize": 5
}
}
}
},
{
"id": 4,
"title": "ICE vs Relay",
"type": "piechart",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 8
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> drop(columns: [\"deployment_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"connection_pair_id\"])\n |> last()\n |> group(columns: [\"connection_type\"])\n |> count()",
"refId": "A"
}
],
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"pieType": "donut",
"tooltip": {
"mode": "multi"
}
}
},
{
"id": 2,
"title": "Connection Stage Durations (avg)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 16
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"signaling_to_connection_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Signaling to Connection\"})",
"refId": "A"
},
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"connection_to_wg_handshake_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> mean()\n |> drop(columns: [\"_start\", \"_stop\", \"_measurement\", \"_time\", \"_field\"])\n |> rename(columns: {_value: \"Avg Connection to WG Handshake\"})",
"refId": "B"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0
}
},
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"orientation": "horizontal",
"displayMode": "gradient"
}
},
{
"id": 3,
"title": "Total Connection Time",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 16
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_peer_connection\" and r._field == \"total_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"connection_type\", \"attempt_type\", \"version\", \"os\", \"arch\", \"peer_id\", \"connection_pair_id\"])\n |> set(key: \"_field\", value: \"Total Connection Time\")",
"refId": "A"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0,
"custom": {
"drawStyle": "points",
"pointSize": 5
}
}
}
},
{
"id": 7,
"title": "Login Duration",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 24
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> map(fn: (r) => ({r with _value: r._value * 1000.0}))\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> set(key: \"_field\", value: \"Login Duration\")",
"refId": "A"
}
],
"fieldConfig": {
"defaults": {
"unit": "ms",
"min": 0,
"custom": {
"drawStyle": "points",
"pointSize": 5
}
}
}
},
{
"id": 8,
"title": "Login Success vs Failure",
"type": "piechart",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 24
},
"targets": [
{
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_login\" and r._field == \"duration_seconds\")\n |> drop(columns: [\"deployment_type\", \"version\", \"os\", \"arch\", \"peer_id\"])\n |> group(columns: [\"result\"])\n |> count()",
"refId": "A"
}
],
"options": {
"reduceOptions": {
"calcs": ["lastNotNull"]
},
"pieType": "donut",
"tooltip": {
"mode": "multi"
}
}
}
],
"schemaVersion": 27,
"version": 2,
"refresh": "30s"
}

View File

@@ -1,15 +0,0 @@
apiVersion: 1
datasources:
- name: InfluxDB
uid: influxdb
type: influxdb
access: proxy
url: http://influxdb:8086
editable: true
jsonData:
version: Flux
organization: netbird
defaultBucket: metrics
secureJsonData:
token: ${INFLUXDB_ADMIN_TOKEN}

View File

@@ -1,25 +0,0 @@
#!/bin/bash
# Creates a scoped InfluxDB read-only token for Grafana.
# Clients do not need a token — they push via the ingest server.
BUCKET_ID=$(influx bucket list --org netbird --name metrics --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1)
ORG_ID=$(influx org list --name netbird --json | grep -oP '"id"\s*:\s*"\K[^"]+' | head -1)
if [[ -z "$BUCKET_ID" ]] || [[ -z "$ORG_ID" ]]; then
echo "ERROR: Could not determine bucket or org ID" >&2
echo "BUCKET_ID=$BUCKET_ID ORG_ID=$ORG_ID" >&2
exit 1
fi
# Create read-only token for Grafana
READ_TOKEN=$(influx auth create \
--org netbird \
--read-bucket "$BUCKET_ID" \
--description "Grafana read-only token" \
--json | grep -oP '"token"\s*:\s*"\K[^"]+' | head -1)
echo ""
echo "============================================"
echo "GRAFANA READ-ONLY TOKEN:"
echo "$READ_TOKEN"
echo "============================================"

View File

@@ -1,10 +0,0 @@
FROM golang:1.25-alpine AS build
WORKDIR /app
COPY go.mod main.go ./
RUN CGO_ENABLED=0 go build -o ingest .
FROM alpine:3.20
RUN adduser -D -H ingest
COPY --from=build /app/ingest /usr/local/bin/ingest
USER ingest
ENTRYPOINT ["ingest"]

View File

@@ -1,11 +0,0 @@
module github.com/netbirdio/netbird/client/internal/metrics/infra/ingest
go 1.25
require github.com/stretchr/testify v1.11.1
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -1,10 +0,0 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,355 +0,0 @@
package main
import (
"bytes"
"compress/gzip"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strconv"
"strings"
"time"
)
const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
type measurementSpec struct {
allowedFields map[string]bool
allowedTags map[string]bool
}
var allowedMeasurements = map[string]measurementSpec{
"netbird_peer_connection": {
allowedFields: map[string]bool{
"signaling_to_connection_seconds": true,
"connection_to_wg_handshake_seconds": true,
"total_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"connection_type": true,
"attempt_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"connection_pair_id": true,
},
},
"netbird_sync": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"result": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
},
},
}
func main() {
listenAddr := envOr("INGEST_LISTEN_ADDR", defaultListenAddr)
influxURL := envOr("INFLUXDB_URL", defaultInfluxDBURL)
influxToken := os.Getenv("INFLUXDB_TOKEN")
if influxToken == "" {
log.Fatal("INFLUXDB_TOKEN is required")
}
client := &http.Client{Timeout: 10 * time.Second}
http.HandleFunc("/", handleIngest(client, influxURL, influxToken))
// Build config JSON once at startup from env vars
configJSON := buildConfigJSON()
if configJSON != nil {
log.Printf("serving remote config at /config")
}
http.HandleFunc("/config", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if configJSON == nil {
http.Error(w, "config not configured", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(configJSON) //nolint:errcheck
})
http.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ok") //nolint:errcheck
})
log.Printf("ingest server listening on %s, forwarding to %s", listenAddr, influxURL)
if err := http.ListenAndServe(listenAddr, nil); err != nil { //nolint:gosec
log.Fatal(err)
}
}
func handleIngest(client *http.Client, influxURL, influxToken string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := validateAuth(r); err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
body, err := readBody(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(body) > maxBodySize {
http.Error(w, "body too large", http.StatusRequestEntityTooLarge)
return
}
validated, err := validateLineProtocol(body)
if err != nil {
log.Printf("WARN validation failed from %s: %v", r.RemoteAddr, err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
forwardToInflux(w, r, client, influxURL, influxToken, validated)
}
}
func forwardToInflux(w http.ResponseWriter, r *http.Request, client *http.Client, influxURL, influxToken string, body []byte) {
req, err := http.NewRequestWithContext(r.Context(), http.MethodPost, influxURL, bytes.NewReader(body))
if err != nil {
log.Printf("ERROR create request: %v", err)
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
req.Header.Set("Content-Type", "text/plain; charset=utf-8")
req.Header.Set("Authorization", "Token "+influxToken)
resp, err := client.Do(req)
if err != nil {
log.Printf("ERROR forward to influxdb: %v", err)
http.Error(w, "upstream error", http.StatusBadGateway)
return
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body) //nolint:errcheck
}
// validateAuth checks that the X-Peer-ID header contains a valid hashed peer ID.
func validateAuth(r *http.Request) error {
peerID := r.Header.Get("X-Peer-ID")
if peerID == "" {
return fmt.Errorf("missing X-Peer-ID header")
}
if len(peerID) != peerIDLength {
return fmt.Errorf("invalid X-Peer-ID header length")
}
if _, err := hex.DecodeString(peerID); err != nil {
return fmt.Errorf("invalid X-Peer-ID header format")
}
return nil
}
// readBody reads the request body, decompressing gzip if Content-Encoding indicates it.
func readBody(r *http.Request) ([]byte, error) {
reader := io.LimitReader(r.Body, maxBodySize+1)
if r.Header.Get("Content-Encoding") == "gzip" {
gz, err := gzip.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("invalid gzip: %w", err)
}
defer gz.Close()
reader = io.LimitReader(gz, maxBodySize+1)
}
return io.ReadAll(reader)
}
// validateLineProtocol parses InfluxDB line protocol lines,
// whitelists measurements and fields, and checks value bounds.
func validateLineProtocol(body []byte) ([]byte, error) {
lines := strings.Split(strings.TrimSpace(string(body)), "\n")
var valid []string
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
if err := validateLine(line); err != nil {
return nil, err
}
valid = append(valid, line)
}
if len(valid) == 0 {
return nil, fmt.Errorf("no valid lines")
}
return []byte(strings.Join(valid, "\n") + "\n"), nil
}
func validateLine(line string) error {
// line protocol: measurement,tag=val,tag=val field=val,field=val timestamp
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
return fmt.Errorf("invalid line protocol: %q", truncate(line, 100))
}
// parts[0] is "measurement,tag=val,tag=val"
measurementAndTags := strings.Split(parts[0], ",")
measurement := measurementAndTags[0]
spec, ok := allowedMeasurements[measurement]
if !ok {
return fmt.Errorf("unknown measurement: %q", measurement)
}
// Validate tags (everything after measurement name in parts[0])
for _, tagPair := range measurementAndTags[1:] {
if err := validateTag(tagPair, measurement, spec.allowedTags); err != nil {
return err
}
}
// Validate fields
for _, pair := range strings.Split(parts[1], ",") {
if err := validateField(pair, measurement, spec.allowedFields); err != nil {
return err
}
}
return nil
}
func validateTag(pair, measurement string, allowedTags map[string]bool) error {
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
return fmt.Errorf("invalid tag: %q", pair)
}
tagName := kv[0]
if !allowedTags[tagName] {
return fmt.Errorf("unknown tag %q in measurement %q", tagName, measurement)
}
if len(kv[1]) > maxTagValueLength {
return fmt.Errorf("tag value too long for %q: %d > %d", tagName, len(kv[1]), maxTagValueLength)
}
return nil
}
func validateField(pair, measurement string, allowedFields map[string]bool) error {
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 {
return fmt.Errorf("invalid field: %q", pair)
}
fieldName := kv[0]
if !allowedFields[fieldName] {
return fmt.Errorf("unknown field %q in measurement %q", fieldName, measurement)
}
val, err := strconv.ParseFloat(kv[1], 64)
if err != nil {
return fmt.Errorf("invalid field value %q for %q", kv[1], fieldName)
}
if val < 0 {
return fmt.Errorf("negative value for %q: %g", fieldName, val)
}
if strings.HasSuffix(fieldName, "_seconds") && val > maxDurationSeconds {
return fmt.Errorf("%q too large: %g > %g", fieldName, val, maxDurationSeconds)
}
return nil
}
// buildConfigJSON builds the remote config JSON from env vars.
// Returns nil if required vars are not set.
func buildConfigJSON() []byte {
serverURL := os.Getenv("CONFIG_METRICS_SERVER_URL")
versionSince := envOr("CONFIG_VERSION_SINCE", "0.0.0")
versionUntil := envOr("CONFIG_VERSION_UNTIL", "99.99.99")
periodMinutes := envOr("CONFIG_PERIOD_MINUTES", "5")
if serverURL == "" {
return nil
}
period, err := strconv.Atoi(periodMinutes)
if err != nil || period <= 0 {
log.Printf("WARN invalid CONFIG_PERIOD_MINUTES: %q, using 5", periodMinutes)
period = 5
}
cfg := map[string]any{
"server_url": serverURL,
"version-since": versionSince,
"version-until": versionUntil,
"period_minutes": period,
}
data, err := json.Marshal(cfg)
if err != nil {
log.Printf("ERROR failed to marshal config: %v", err)
return nil
}
return data
}
func envOr(key, defaultVal string) string {
if v := os.Getenv(key); v != "" {
return v
}
return defaultVal
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}

View File

@@ -1,124 +0,0 @@
package main
import (
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateLine_ValidPeerConnection(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789,connection_pair_id=pair1234 signaling_to_connection_seconds=1.5,connection_to_wg_handshake_seconds=0.5,total_seconds=2 1234567890`
assert.NoError(t, validateLine(line))
}
func TestValidateLine_ValidSync(t *testing.T) {
line := `netbird_sync,deployment_type=selfhosted,version=2.0.0,os=darwin,arch=arm64,peer_id=abcdef0123456789 duration_seconds=1.5 1234567890`
assert.NoError(t, validateLine(line))
}
func TestValidateLine_ValidLogin(t *testing.T) {
line := `netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abcdef0123456789 duration_seconds=3.2 1234567890`
assert.NoError(t, validateLine(line))
}
func TestValidateLine_UnknownMeasurement(t *testing.T) {
line := `unknown_metric,foo=bar value=1 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown measurement")
}
func TestValidateLine_UnknownTag(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,evil_tag=injected,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown tag")
}
func TestValidateLine_UnknownField(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc injected_field=1 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown field")
}
func TestValidateLine_NegativeValue(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=-1.5 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "negative")
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TagValueTooLong(t *testing.T) {
longTag := strings.Repeat("a", maxTagValueLength+1)
line := `netbird_sync,deployment_type=` + longTag + `,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "tag value too long")
}
func TestValidateLineProtocol_MultipleLines(t *testing.T) {
body := []byte(
"netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" +
"netbird_login,deployment_type=cloud,result=success,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=2.0 1234567890\n",
)
validated, err := validateLineProtocol(body)
require.NoError(t, err)
assert.Contains(t, string(validated), "netbird_sync")
assert.Contains(t, string(validated), "netbird_login")
}
func TestValidateLineProtocol_RejectsOnBadLine(t *testing.T) {
body := []byte(
"netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=1.5 1234567890\n" +
"evil_metric,foo=bar value=1 1234567890\n",
)
_, err := validateLineProtocol(body)
require.Error(t, err)
}
func TestValidateAuth(t *testing.T) {
tests := []struct {
name string
peerID string
wantErr bool
}{
{"valid hex", "abcdef0123456789", false},
{"empty", "", true},
{"too short", "abcdef01234567", true},
{"too long", "abcdef01234567890", true},
{"invalid hex", "ghijklmnopqrstuv", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, _ := http.NewRequest(http.MethodPost, "/", nil)
if tt.peerID != "" {
r.Header.Set("X-Peer-ID", tt.peerID)
}
err := validateAuth(r)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -1,224 +0,0 @@
package metrics
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
// AgentInfo holds static information about the agent
type AgentInfo struct {
DeploymentType DeploymentType
Version string
OS string // runtime.GOOS (linux, darwin, windows, etc.)
Arch string // runtime.GOARCH (amd64, arm64, etc.)
peerID string // anonymised peer identifier (SHA-256 of WireGuard public key)
}
// peerIDFromPublicKey returns a truncated SHA-256 hash (8 bytes / 16 hex chars) of the given WireGuard public key.
func peerIDFromPublicKey(pubKey string) string {
hash := sha256.Sum256([]byte(pubKey))
return hex.EncodeToString(hash[:8])
}
// connectionPairID returns a deterministic identifier for a connection between two peers.
// It sorts the two peer IDs before hashing so the same pair always produces the same ID
// regardless of which side computes it.
func connectionPairID(peerID1, peerID2 string) string {
a, b := peerID1, peerID2
if a > b {
a, b = b, a
}
hash := sha256.Sum256([]byte(a + b))
return hex.EncodeToString(hash[:8])
}
// metricsImplementation defines the internal interface for metrics implementations
type metricsImplementation interface {
// RecordConnectionStages records connection stage metrics from timestamps
RecordConnectionStages(
ctx context.Context,
agentInfo AgentInfo,
connectionPairID string,
connectionType ConnectionType,
isReconnection bool,
timestamps ConnectionStageTimestamps,
)
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
// Export exports metrics in InfluxDB line protocol format
Export(w io.Writer) error
// Reset clears all collected metrics
Reset()
}
type ClientMetrics struct {
impl metricsImplementation
agentInfo AgentInfo
mu sync.RWMutex
push *Push
pushMu sync.Mutex
wg sync.WaitGroup
pushCancel context.CancelFunc
}
// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
ConnectionReady time.Time
WgHandshakeSuccess time.Time
}
// String returns a human-readable representation of the connection stage timestamps
func (c ConnectionStageTimestamps) String() string {
return fmt.Sprintf("ConnectionStageTimestamps{SignalingReceived=%v, ConnectionReady=%v, WgHandshakeSuccess=%v}",
c.SignalingReceived.Format(time.RFC3339Nano),
c.ConnectionReady.Format(time.RFC3339Nano),
c.WgHandshakeSuccess.Format(time.RFC3339Nano),
)
}
// RecordConnectionStages calculates stage durations from timestamps and records them.
// remotePubKey is the remote peer's WireGuard public key; it will be hashed for anonymisation.
func (c *ClientMetrics) RecordConnectionStages(
ctx context.Context,
remotePubKey string,
connectionType ConnectionType,
isReconnection bool,
timestamps ConnectionStageTimestamps,
) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
remotePeerID := peerIDFromPublicKey(remotePubKey)
pairID := connectionPairID(agentInfo.peerID, remotePeerID)
c.impl.RecordConnectionStages(ctx, agentInfo, pairID, connectionType, isReconnection, timestamps)
}
// RecordSyncDuration records the duration of sync message processing
func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordLoginDuration(ctx, agentInfo, duration, success)
}
// UpdateAgentInfo updates the agent information (e.g., when switching profiles).
// publicKey is the WireGuard public key; it will be hashed for anonymisation.
func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) {
if c == nil {
return
}
agentInfo.peerID = peerIDFromPublicKey(publicKey)
c.mu.Lock()
c.agentInfo = agentInfo
c.mu.Unlock()
c.pushMu.Lock()
push := c.push
c.pushMu.Unlock()
if push != nil {
push.SetPeerID(agentInfo.peerID)
}
}
// Export exports metrics to the writer
func (c *ClientMetrics) Export(w io.Writer) error {
if c == nil {
return nil
}
return c.impl.Export(w)
}
// StartPush starts periodic pushing of metrics with the given configuration
// Precedence: PushConfig.ServerAddress > remote config server_url
func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
if c == nil {
return
}
c.pushMu.Lock()
defer c.pushMu.Unlock()
if c.push != nil {
log.Warnf("metrics push already running")
return
}
c.mu.RLock()
agentVersion := c.agentInfo.Version
peerID := c.agentInfo.peerID
c.mu.RUnlock()
configManager := remoteconfig.NewManager(getMetricsConfigURL(), remoteconfig.DefaultMinRefreshInterval)
push, err := NewPush(c.impl, configManager, config, agentVersion)
if err != nil {
log.Errorf("failed to create metrics push: %v", err)
return
}
push.SetPeerID(peerID)
ctx, cancel := context.WithCancel(ctx)
c.pushCancel = cancel
c.wg.Add(1)
go func() {
defer c.wg.Done()
push.Start(ctx)
}()
c.push = push
}
func (c *ClientMetrics) StopPush() {
if c == nil {
return
}
c.pushMu.Lock()
defer c.pushMu.Unlock()
if c.push == nil {
return
}
c.pushCancel()
c.wg.Wait()
c.push = nil
}

View File

@@ -1,11 +0,0 @@
//go:build !js
package metrics
// NewClientMetrics creates a new ClientMetrics instance
func NewClientMetrics(agentInfo AgentInfo) *ClientMetrics {
return &ClientMetrics{
impl: newInfluxDBMetrics(),
agentInfo: agentInfo,
}
}

View File

@@ -1,8 +0,0 @@
//go:build js
package metrics
// NewClientMetrics returns nil on WASM builds — all ClientMetrics methods are nil-safe.
func NewClientMetrics(AgentInfo) *ClientMetrics {
return nil
}

View File

@@ -1,289 +0,0 @@
package metrics
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"net/http"
"net/url"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
const (
// defaultPushInterval is the default interval for pushing metrics
defaultPushInterval = 5 * time.Minute
)
// defaultMetricsServerURL is used as fallback when NB_METRICS_FORCE_SENDING is true
var defaultMetricsServerURL *url.URL
func init() {
defaultMetricsServerURL, _ = url.Parse("https://ingest.netbird.io")
}
// PushConfig holds configuration for metrics push
type PushConfig struct {
// ServerAddress is the metrics server URL. If nil, uses remote config server_url.
ServerAddress *url.URL
// Interval is how often to push metrics. If 0, uses remote config interval or defaultPushInterval.
Interval time.Duration
// ForceSending skips remote configuration fetch and version checks, pushing unconditionally.
ForceSending bool
}
// PushConfigFromEnv builds a PushConfig from environment variables.
func PushConfigFromEnv() PushConfig {
config := PushConfig{}
config.ForceSending = isForceSending()
config.ServerAddress = getMetricsServerURL()
config.Interval = getMetricsInterval()
return config
}
// remoteConfigProvider abstracts remote push config fetching for testability
type remoteConfigProvider interface {
RefreshIfNeeded(ctx context.Context) *remoteconfig.Config
}
// Push handles periodic pushing of metrics
type Push struct {
metrics metricsImplementation
configManager remoteConfigProvider
agentVersion *goversion.Version
peerID string
peerMu sync.RWMutex
client *http.Client
cfgForceSending bool
cfgInterval time.Duration
cfgAddress *url.URL
}
// NewPush creates a new Push instance with configuration resolution
func NewPush(metrics metricsImplementation, configManager remoteConfigProvider, config PushConfig, agentVersion string) (*Push, error) {
var cfgInterval time.Duration
var cfgAddress *url.URL
if config.ForceSending {
cfgInterval = config.Interval
if config.Interval <= 0 {
cfgInterval = defaultPushInterval
}
cfgAddress = config.ServerAddress
if cfgAddress == nil {
cfgAddress = defaultMetricsServerURL
}
} else {
cfgAddress = config.ServerAddress
if config.Interval < 0 {
log.Warnf("negative metrics push interval %s", config.Interval)
} else {
cfgInterval = config.Interval
}
}
parsedVersion, err := goversion.NewVersion(agentVersion)
if err != nil {
if !config.ForceSending {
return nil, fmt.Errorf("parse agent version %q: %w", agentVersion, err)
}
}
return &Push{
metrics: metrics,
configManager: configManager,
agentVersion: parsedVersion,
cfgForceSending: config.ForceSending,
cfgInterval: cfgInterval,
cfgAddress: cfgAddress,
client: &http.Client{
Timeout: 10 * time.Second,
},
}, nil
}
// SetPeerID updates the hashed peer ID used for the Authorization header.
func (p *Push) SetPeerID(peerID string) {
p.peerMu.Lock()
p.peerID = peerID
p.peerMu.Unlock()
}
// Start starts the periodic push loop.
// The env interval override controls tick frequency but does not bypass remote config
// version gating. Use ForceSending to skip remote config entirely.
func (p *Push) Start(ctx context.Context) {
// Log initial state
switch {
case p.cfgForceSending:
log.Infof("started metrics push with force sending to %s, interval %s", p.cfgAddress, p.cfgInterval)
case p.cfgAddress != nil:
log.Infof("started metrics push with server URL override: %s", p.cfgAddress.String())
default:
log.Infof("started metrics push, server URL will be resolved from remote config")
}
timer := time.NewTimer(0) // fire immediately on first iteration
defer timer.Stop()
for {
select {
case <-ctx.Done():
log.Debug("stopping metrics push")
return
case <-timer.C:
}
pushURL, interval := p.resolve(ctx)
if pushURL != "" {
if err := p.push(ctx, pushURL); err != nil {
log.Errorf("failed to push metrics: %v", err)
}
}
if interval <= 0 {
interval = defaultPushInterval
}
timer.Reset(interval)
}
}
// resolve returns the push URL and interval for the next cycle.
// Returns empty pushURL to skip this cycle.
func (p *Push) resolve(ctx context.Context) (pushURL string, interval time.Duration) {
if p.cfgForceSending {
return p.resolveServerURL(nil), p.cfgInterval
}
config := p.configManager.RefreshIfNeeded(ctx)
if config == nil {
log.Debug("no metrics push config available, waiting to retry")
return "", defaultPushInterval
}
// prefer env variables instead of remote config
if p.cfgInterval > 0 {
interval = p.cfgInterval
} else {
interval = config.Interval
}
if !isVersionInRange(p.agentVersion, config.VersionSince, config.VersionUntil) {
log.Debugf("agent version %s not in range [%s, %s), skipping metrics push",
p.agentVersion, config.VersionSince, config.VersionUntil)
return "", interval
}
pushURL = p.resolveServerURL(&config.ServerURL)
if pushURL == "" {
log.Warn("no metrics server URL available, skipping push")
}
return pushURL, interval
}
// push exports metrics and sends them to the metrics server
func (p *Push) push(ctx context.Context, pushURL string) error {
// Export metrics without clearing
var buf bytes.Buffer
if err := p.metrics.Export(&buf); err != nil {
return fmt.Errorf("export metrics: %w", err)
}
// Don't push if there are no metrics
if buf.Len() == 0 {
log.Tracef("no metrics to push")
return nil
}
// Gzip compress the body
compressed, err := gzipCompress(buf.Bytes())
if err != nil {
return fmt.Errorf("gzip compress: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "POST", pushURL, compressed)
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "text/plain; charset=utf-8")
req.Header.Set("Content-Encoding", "gzip")
p.peerMu.RLock()
peerID := p.peerID
p.peerMu.RUnlock()
if peerID != "" {
req.Header.Set("X-Peer-ID", peerID)
}
// Send request
resp, err := p.client.Do(req)
if err != nil {
return fmt.Errorf("send request: %w", err)
}
defer func() {
if resp.Body == nil {
return
}
if err := resp.Body.Close(); err != nil {
log.Warnf("failed to close response body: %v", err)
}
}()
// Check response status
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return fmt.Errorf("push failed with status %d", resp.StatusCode)
}
log.Debugf("successfully pushed metrics to %s", pushURL)
p.metrics.Reset()
return nil
}
// resolveServerURL determines the push URL.
// Precedence: envAddress (env var) > remote config server_url
func (p *Push) resolveServerURL(remoteServerURL *url.URL) string {
var baseURL *url.URL
if p.cfgAddress != nil {
baseURL = p.cfgAddress
} else {
baseURL = remoteServerURL
}
if baseURL == nil {
return ""
}
return baseURL.String()
}
// gzipCompress compresses data using gzip and returns the compressed buffer.
func gzipCompress(data []byte) (*bytes.Buffer, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
if _, err := gz.Write(data); err != nil {
_ = gz.Close()
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
return &buf, nil
}
// isVersionInRange checks if current falls within [since, until)
func isVersionInRange(current, since, until *goversion.Version) bool {
return !current.LessThan(since) && current.LessThan(until)
}

View File

@@ -1,343 +0,0 @@
package metrics
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync/atomic"
"testing"
"time"
goversion "github.com/hashicorp/go-version"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/metrics/remoteconfig"
)
func mustVersion(s string) *goversion.Version {
v, err := goversion.NewVersion(s)
if err != nil {
panic(err)
}
return v
}
func mustURL(s string) url.URL {
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return *u
}
func parseURL(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}
func testConfig(serverURL, since, until string, period time.Duration) *remoteconfig.Config {
return &remoteconfig.Config{
ServerURL: mustURL(serverURL),
VersionSince: mustVersion(since),
VersionUntil: mustVersion(until),
Interval: period,
}
}
// mockConfigProvider implements remoteConfigProvider for testing
type mockConfigProvider struct {
config *remoteconfig.Config
}
func (m *mockConfigProvider) RefreshIfNeeded(_ context.Context) *remoteconfig.Config {
return m.config
}
// mockMetrics implements metricsImplementation for testing
type mockMetrics struct {
exportData string
}
func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ string, _ ConnectionType, _ bool, _ ConnectionStageTimestamps) {
}
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}
func (m *mockMetrics) Export(w io.Writer) error {
if m.exportData != "" {
_, err := w.Write([]byte(m.exportData))
return err
}
return nil
}
func (m *mockMetrics) Reset() {
}
func TestPush_OverrideIntervalPushes(t *testing.T) {
var pushCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pushCount.Add(1)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 50 * time.Millisecond,
ServerAddress: parseURL(server.URL),
}, "1.0.0")
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
push.Start(ctx)
close(done)
}()
require.Eventually(t, func() bool {
return pushCount.Load() >= 3
}, 2*time.Second, 10*time.Millisecond)
cancel()
<-done
}
func TestPush_RemoteConfigVersionInRange(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_RemoteConfigVersionOutOfRange(t *testing.T) {
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig("http://localhost", "1.0.0", "1.5.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{}, "2.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_NoConfigReturnsDefault(t *testing.T) {
metrics := &mockMetrics{}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL)
assert.Equal(t, defaultPushInterval, interval)
}
func TestPush_OverrideIntervalRespectsVersionCheck(t *testing.T) {
metrics := &mockMetrics{}
configProvider := &mockConfigProvider{config: testConfig("http://localhost", "3.0.0", "4.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
ServerAddress: parseURL("http://localhost"),
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL) // version out of range
assert.Equal(t, 30*time.Second, interval) // but uses override interval
}
func TestPush_OverrideIntervalUsedWhenVersionInRange(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
}, "1.5.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, 30*time.Second, interval)
}
func TestPush_NoMetricsSkipsPush(t *testing.T) {
var pushCount atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pushCount.Add(1)
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: ""} // no metrics to export
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.0.0")
require.NoError(t, err)
err = push.push(context.Background(), server.URL)
assert.NoError(t, err)
assert.Equal(t, int32(0), pushCount.Load())
}
func TestPush_ServerURLFromRemoteConfig(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{}, "1.5.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Contains(t, pushURL, server.URL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_ServerAddressOverridesTakePrecedenceOverRemoteConfig(t *testing.T) {
overrideServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer overrideServer.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig("http://remote-config-server", "1.0.0", "2.0.0", 1*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
ServerAddress: parseURL(overrideServer.URL),
}, "1.5.0")
require.NoError(t, err)
pushURL, _ := push.resolve(context.Background())
assert.Contains(t, pushURL, overrideServer.URL)
assert.NotContains(t, pushURL, "remote-config-server")
}
func TestPush_OverrideIntervalWithoutOverrideURL_UsesRemoteConfigURL(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: testConfig(server.URL, "1.0.0", "2.0.0", 60*time.Minute)}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Contains(t, pushURL, server.URL)
assert.Equal(t, 30*time.Second, interval)
}
func TestPush_NoConfigSkipsPush(t *testing.T) {
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{
Interval: 30 * time.Second,
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.Empty(t, pushURL)
assert.Equal(t, defaultPushInterval, interval) // no config available, use default retry interval
}
func TestPush_ForceSendingSkipsRemoteConfig(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{
ForceSending: true,
Interval: 1 * time.Minute,
ServerAddress: parseURL(server.URL),
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, 1*time.Minute, interval)
}
func TestPush_ForceSendingUsesDefaultInterval(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
metrics := &mockMetrics{exportData: "test_metric 1\n"}
configProvider := &mockConfigProvider{config: nil}
push, err := NewPush(metrics, configProvider, PushConfig{
ForceSending: true,
ServerAddress: parseURL(server.URL),
}, "1.0.0")
require.NoError(t, err)
pushURL, interval := push.resolve(context.Background())
assert.NotEmpty(t, pushURL)
assert.Equal(t, defaultPushInterval, interval)
}
func TestIsVersionInRange(t *testing.T) {
tests := []struct {
name string
current string
since string
until string
expected bool
}{
{"at lower bound inclusive", "1.2.2", "1.2.2", "1.2.3", true},
{"in range", "1.2.2", "1.2.0", "1.3.0", true},
{"at upper bound exclusive", "1.2.3", "1.2.2", "1.2.3", false},
{"below range", "1.2.1", "1.2.2", "1.2.3", false},
{"above range", "1.3.0", "1.2.2", "1.2.3", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isVersionInRange(mustVersion(tt.current), mustVersion(tt.since), mustVersion(tt.until)))
})
}
}

View File

@@ -1,149 +0,0 @@
package remoteconfig
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
)
const (
DefaultMinRefreshInterval = 30 * time.Minute
)
// Config holds the parsed remote push configuration
type Config struct {
ServerURL url.URL
VersionSince *goversion.Version
VersionUntil *goversion.Version
Interval time.Duration
}
// rawConfig is the JSON wire format fetched from the remote server
type rawConfig struct {
ServerURL string `json:"server_url"`
VersionSince string `json:"version-since"`
VersionUntil string `json:"version-until"`
PeriodMinutes int `json:"period_minutes"`
}
// Manager handles fetching and caching remote push configuration
type Manager struct {
configURL string
minRefreshInterval time.Duration
client *http.Client
mu sync.Mutex
lastConfig *Config
lastFetched time.Time
}
func NewManager(configURL string, minRefreshInterval time.Duration) *Manager {
return &Manager{
configURL: configURL,
minRefreshInterval: minRefreshInterval,
client: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// RefreshIfNeeded fetches new config if the cached one is stale.
// Returns the current config (possibly just fetched) or nil if unavailable.
func (m *Manager) RefreshIfNeeded(ctx context.Context) *Config {
m.mu.Lock()
defer m.mu.Unlock()
if m.isConfigFresh() {
return m.lastConfig
}
fetchedConfig, err := m.fetch(ctx)
m.lastFetched = time.Now()
if err != nil {
log.Warnf("failed to fetch metrics remote config: %v", err)
return m.lastConfig // return cached (may be nil)
}
m.lastConfig = fetchedConfig
log.Tracef("fetched metrics remote config: version-since=%s version-until=%s period=%s",
fetchedConfig.VersionSince, fetchedConfig.VersionUntil, fetchedConfig.Interval)
return fetchedConfig
}
func (m *Manager) isConfigFresh() bool {
if m.lastConfig == nil {
return false
}
return time.Since(m.lastFetched) < m.minRefreshInterval
}
func (m *Manager) fetch(ctx context.Context) (*Config, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.configURL, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
resp, err := m.client.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer func() {
if resp.Body != nil {
_ = resp.Body.Close()
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
if err != nil {
return nil, fmt.Errorf("read body: %w", err)
}
var raw rawConfig
if err := json.Unmarshal(body, &raw); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if raw.PeriodMinutes <= 0 {
return nil, fmt.Errorf("invalid period_minutes: %d", raw.PeriodMinutes)
}
if raw.ServerURL == "" {
return nil, fmt.Errorf("server_url is required")
}
serverURL, err := url.Parse(raw.ServerURL)
if err != nil {
return nil, fmt.Errorf("parse server_url %q: %w", raw.ServerURL, err)
}
since, err := goversion.NewVersion(raw.VersionSince)
if err != nil {
return nil, fmt.Errorf("parse version-since %q: %w", raw.VersionSince, err)
}
until, err := goversion.NewVersion(raw.VersionUntil)
if err != nil {
return nil, fmt.Errorf("parse version-until %q: %w", raw.VersionUntil, err)
}
return &Config{
ServerURL: *serverURL,
VersionSince: since,
VersionUntil: until,
Interval: time.Duration(raw.PeriodMinutes) * time.Minute,
}, nil
}

Some files were not shown because too many files have changed in this diff Show More