Compare commits

...

51 Commits

Author SHA1 Message Date
mlsmaycon
ea997f4a26 test freebsd 2026-01-03 09:26:03 +01:00
Zoltan Papp
73201c4f3e Add conditional checks for FreeBSD diff file generation in release workflow (#5001) 2025-12-29 12:47:38 +01:00
Carlos Hernandez
33d1761fe8 Apply DNS host config on change only (#4695)
Adds a per-instance uint64 hash to DefaultServer to detect identical merged host DNS configs (including extra domains). applyHostConfig computes and compares the hash, skips applying if unchanged, treats hash errors as a fail-safe (proceed to apply), and updates the stored hash only after successful hashing and apply.
2025-12-29 12:43:57 +01:00
August
aa914a0f26 [docs] Fix broken image link (#4876) 2025-12-24 22:06:35 +05:00
Maycon Santos
ab6a9e85de [misc] Use new sign pipelines 0.1.0 (#4993) 2025-12-24 22:03:14 +05:00
Maycon Santos
d3b123c76d [ci] Add FreeBSD port release job to GitHub Actions (#4916)
adds a job that produces new freebsd release files
2025-12-24 11:22:33 +01:00
Viktor Liu
fc4932a23f [client] Fix Linux UI flickering on state updates (#4886) 2025-12-24 11:06:13 +01:00
Zoltan Papp
b7e98acd1f [client] Android profile switch (#4884)
Expose the profile-manager service for Android. Logout was not part of the manager service implementation. In the future, I recommend moving this logic there.
2025-12-22 22:09:05 +01:00
Maycon Santos
433bc4ead9 [client] lookup for management domains using an additional timeout (#4983)
in some cases iOS and macOS may be locked when looking for management domains during network changes

This change introduce an additional timeout on top of the context call
2025-12-22 20:04:52 +01:00
Zoltan Papp
011cc81678 [client, management] auto-update (#4732) 2025-12-19 19:57:39 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
Pascal Fischer
c29bb1a289 [management] use xid as request id for logging (#4955) 2025-12-16 14:02:37 +01:00
Zoltan Papp
447cd287f5 [ci] Add local lint setup with pre-push hook to catch issues early (#4925)
* Add local lint setup with pre-push hook to catch issues early

Developers can now catch lint issues before pushing, reducing CI failures
and iteration time. The setup uses golangci-lint locally with the same
configuration as CI.

Setup:
- Run `make setup-hooks` once after cloning
- Pre-push hook automatically lints changed files (~90s)
- Use `make lint` to manually check changed files
- Use `make lint-all` to run full CI-equivalent lint

The Makefile auto-installs golangci-lint to ./bin/ using go install to
match the Go version in go.mod, avoiding version compatibility issues.

---------

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
2025-12-15 10:34:48 +01:00
Zoltan Papp
5748bdd64e Add health-check agent recognition to avoid error logs (#4917)
Health-check connections now send a properly formatted auth message
with a well-known peer ID instead of immediately closing. The server
recognizes this peer ID and handles the connection gracefully with a
debug log instead of error logs.
2025-12-15 10:28:25 +01:00
Diego Romar
08f31fbcb3 [iOS] Add force relay connection on iOS (#4928)
* [ios] Add a bogus test to check iOS behavior when setting environment variables

* [ios] Revert "Add a bogus test to check iOS behavior when setting environment variables"

This reverts commit 90ca01105a6b0f4471aac07a63fc95e5d4eaef9b.

* [ios] Add EnvList struct to export and import environment variables

* [ios] Add envList parameter to the iOS Client Run method

* [ios] Add some debug logging to exportEnvVarList

* Add "//go:build ios" to client/ios/NetBirdSDK files
2025-12-12 14:29:58 -03:00
Bethuel Mmbaga
932c02eaab [management] Approve all pending peers when peer approval is disabled (#4806) 2025-12-12 18:49:57 +03:00
Pascal Fischer
abcbde26f9 [management] remove context from store methods (#4940) 2025-12-11 21:45:47 +01:00
Pascal Fischer
90e3b8009f [management] Fix sync metrics (#4939) 2025-12-11 20:11:12 +01:00
Pascal Fischer
94d34dc0c5 [management] monitoring updates (#4937) 2025-12-11 18:29:15 +01:00
Pascal Fischer
44851e06fb [management] cleanup logs (#4933) 2025-12-10 19:26:51 +01:00
Viktor Liu
3f4f825ec1 [client] Fix DNS forwarder returning broken records on 4 to 6 mapped IP addresses (#4887) 2025-12-05 17:42:49 +01:00
Viktor Liu
f538e6e9ae [client] Use setsid to avoid the parent process from being killed via HUP by login (#4900) 2025-12-05 03:29:27 +01:00
Maycon Santos
cb6b086164 [client] Reorder subsystem shutdown so peer removal goes first (#4914)
Remove peers before DNS and routes
2025-12-04 21:01:22 +01:00
Zoltan Papp
71b6855e09 [client] Fix engine shutdown deadlock and sync-signal message handling races (#4891)
* Fix engine shutdown deadlock and message handling races

- Release syncMsgMux before waiting for shutdownWg to prevent deadlock
- Check context inside lock in handleSync and receiveSignalEvents
- Prevents nil pointer access when messages arrive during engine stop
2025-12-04 19:51:50 +01:00
Viktor Liu
9bdc4908fb [client] Passthrough all non-NetBird chains to prevent them from dropping NetBird traffic (#4899) 2025-12-04 19:16:38 +01:00
Bethuel Mmbaga
031ab11178 [client] Remove select account prompt (#4912)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-12-04 14:57:29 +01:00
Zoltan Papp
d2e48d4f5e [relay] Use instanceURL instead of Exposed address. (#4905)
Replaces string-based exposed address handling with URL-based InstanceURL() (type url.URL) across relay/server and relay/healthcheck; adds SchemeREL/SchemeRELS constants; updates getInstanceURL to return *url.URL with scheme and TLS validation; adjusts WS dialing and health-check logic to use URL fields.
2025-12-03 18:42:53 +01:00
Bethuel Mmbaga
27dd97c9c4 [management] Add support to disable geolocation service (#4901) 2025-12-03 14:45:59 +03:00
Maycon Santos
e87b4ace11 [client] Add sleep state tracking to handle wakeup/sleep events reliably (#4894)
Adds a new NotifyOSLifecycle RPC and server handler to centralize OS sleep/wake handling, introduces Server.sleepTriggeredDown for coordination, updates client UI to call the new RPC, and adjusts the internal sleep event enum zero-value semantics.
2025-12-03 11:53:39 +01:00
Pascal Fischer
a232cf614c [management] record pat usage metrics (#4888) 2025-12-02 18:31:59 +01:00
Maycon Santos
a293f760af [client] Add conditional peer removal logic during shutdown (#4897) 2025-12-02 16:30:15 +01:00
Pascal Fischer
10e9cf8c62 [management] update management integrations (#4895) 2025-12-02 14:13:01 +01:00
Pascal Fischer
7193bd2da7 [management] Refactor network map controller (#4789) 2025-12-02 12:34:28 +01:00
Bethuel Mmbaga
52948ccd61 [management] Add user created activity event (#4893) 2025-12-02 14:17:59 +03:00
Fahri Shihab
4b77359042 [management] Groups API with name query parameter (#4831) 2025-12-01 16:57:42 +01:00
Zoltan Papp
387d43bcc1 [client, management] Add OAuth select_account prompt support to PKCE flow (#4880)
* Add OAuth select_account prompt support to PKCE flow

Extends LoginFlag enum with select_account options to enable
multi-account selection during authentication. This allows users
to choose which account to use when multiple accounts have active
sessions with the identity provider.

The new flags are backward compatible - existing LoginFlag values
(0=prompt login, 1=max_age=0) retain their original behavior.
2025-12-01 14:25:52 +01:00
Zoltan Papp
e47d815dd2 Fix IsAnotherProcessRunning (#4858)
Compare the exact process name rather than searching for a substring of the full path
2025-12-01 14:16:03 +01:00
shuuri-labs
cb83b7c0d3 [relay] use exposed address for healthcheck TLS validation (#4872)
* fix(relay): use exposed address for healthcheck TLS validation

Healthcheck was using listen address (0.0.0.0) instead of exposed address
(domain name) for certificate validation, causing validation to always fail.

Now correctly uses the exposed address where the TLS certificate is valid,
matching real client connection behavior.

* - store exposedAddress directly in Relay struct instead of parsing on every call
- remove unused parseHostPort() function
- remove unused ListenAddress() method from ServiceChecker interface
- improve error logging with address context

* [relay/healthcheck] Remove QUIC health check logic, update WebSocket validation flow

Refactored health check logic by removing QUIC-specific connection validation and simplifying logic for WebSocket protocol. Adjusted certificate validation flow and improved handling of exposed addresses.

* [relay/healthcheck] Fix certificate validation status during health check

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-11-28 21:53:53 +01:00
Zoltan Papp
ddcd182859 [client] Sleep detection on macOS (#4859)
A macOS-specific sleep detection mechanism using IOKit and CoreFoundation via cgo is introduced, with a fallback implementation for unsupported platforms. A public Service wrapper provides an event-driven API translating system sleep/wake events into gRPC calls. The UI client integrates sleep detection to manage connectivity state based on system sleep status.
2025-11-28 17:26:22 +01:00
Maycon Santos
aca0398105 [client] Add excluded port range handling for PKCE flow (#4853) 2025-11-26 16:07:45 +01:00
Viktor Liu
02200d790b [client] Open browser for ssh automatically (#4838) 2025-11-26 16:06:47 +01:00
Bethuel Mmbaga
f31bba87b4 [management] Preserve validator settings on account settings update (#4862) 2025-11-26 17:07:44 +03:00
shuuri-labs
7285fef0f0 feat: Add support for displaying device code (UserCode) on Android TV SSO flow (#4800)
- Modified URLOpener interface to pass userCode alongside URL in login.go
- added ability to force device auth flow
2025-11-25 15:51:16 +01:00
Maycon Santos
20973063d8 [client] Support disable search domain for custom zones (#4826)
Two new boolean flags, SearchDomainDisabled and SkipPTRProcess, are added to CustomZone and its protobuf; they are propagated through the engine to DNS host logic. Host matching now uses SearchDomainDisabled directly, and PTR collection skips zones with SkipPTRProcess; reverse zones are initialized with SearchDomainDisabled: true.
2025-11-24 17:50:08 +01:00
Aziz Hasanain
ba2e9b6d88 [management] Fix SSH JWT issuer derivation for IDPs with path components (#4844) 2025-11-24 12:12:51 +01:00
Viktor Liu
131d7a3694 [client] Make mss clamping optional for nftables (#4843) 2025-11-22 18:57:07 +01:00
Maycon Santos
290fe2d8b9 [client/management/signal/relay] Update go.mod to use Go 1.24.10 and upgrade x/crypto dependencies (#4828)
Upgrade Go toolchain and golang.org/x/* deps to 1.24.10, standardize GitHub Actions to derive Go version from go.mod and adjust checkout ordering, raise WASM size limit to 55 MB, update FreeBSD tarball and gomobile refs, fix a few format-string/logging calls, treat usernames ending with $ as system accounts, and add Windows tests.
2025-11-22 10:10:18 +01:00
Vlad
7fb1a2fe31 [management] removed TestBufferUpdateAccountPeers because it was incorrect (#4839) 2025-11-22 01:23:33 +01:00
Diego Romar
32146e576d [android] allow selection/deselection of network resources on android peers (#4607) 2025-11-21 13:36:33 +01:00
Viktor Liu
1311364397 [client] Increase ssh detection timeout (#4827) 2025-11-20 17:09:22 +01:00
238 changed files with 15212 additions and 2405 deletions

11
.githooks/pre-push Executable file
View File

@@ -0,0 +1,11 @@
#!/bin/bash
echo "Running pre-push hook..."
if ! make lint; then
echo ""
echo "Hint: To push without verification, run:"
echo " git push --no-verify"
exit 1
fi
echo "All checks passed!"

View File

@@ -15,13 +15,14 @@ jobs:
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v4

View File

@@ -25,7 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -30,7 +30,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Get Go environment
@@ -106,15 +106,15 @@ jobs:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -151,15 +151,15 @@ jobs:
needs: [ build-cache ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
@@ -200,7 +200,7 @@ jobs:
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
-e CONTAINER=${CONTAINER} \
golang:1.23-alpine \
golang:1.24-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
@@ -220,15 +220,15 @@ jobs:
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -270,15 +270,15 @@ jobs:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
@@ -321,15 +321,15 @@ jobs:
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -408,15 +408,16 @@ jobs:
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -497,15 +498,15 @@ jobs:
-p 9090:9090 \
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -561,15 +562,15 @@ jobs:
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV

View File

@@ -24,7 +24,7 @@ jobs:
uses: actions/setup-go@v5
id: go
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Get Go environment

View File

@@ -46,7 +46,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
@@ -39,7 +39,7 @@ jobs:
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
- name: gomobile init
run: gomobile init
- name: build android netbird lib
@@ -56,9 +56,9 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.23"
SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,8 +19,102 @@ concurrency:
cancel-in-progress: true
jobs:
release:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release:
runs-on: ubuntu-latest-m
env:
flags: ""
steps:
@@ -40,7 +134,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -136,7 +230,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
@@ -200,7 +294,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23"
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4

View File

@@ -67,10 +67,13 @@ jobs:
- name: Install curl
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@v4
@@ -80,9 +83,6 @@ jobs:
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: Setup MySQL privileges
if: matrix.store == 'mysql'
run: |

View File

@@ -20,7 +20,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
@@ -45,7 +45,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
go-version-file: "go.mod"
- name: Build Wasm client
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
env:
@@ -60,8 +60,8 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 52428800 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
if [ ${SIZE} -gt 57671680 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
exit 1
fi

View File

@@ -136,6 +136,14 @@ checked out and set up:
go mod tidy
```
6. Configure Git hooks for automatic linting:
```bash
make setup-hooks
```
This will configure Git to run linting automatically before each push, helping catch issues early.
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.

27
Makefile Normal file
View File

@@ -0,0 +1,27 @@
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)
@echo "Running lint on changed files..."
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
# Lint entire codebase (slow, matches CI)
lint-all: $(GOLANGCI_LINT)
@echo "Running lint on all files..."
@$(GOLANGCI_LINT) run --timeout=12m
# Just install the linter
lint-install: $(GOLANGCI_LINT)
# Setup git hooks for all developers
setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"

View File

@@ -1,4 +1,3 @@
<div align="center">
<br/>
<br/>
@@ -113,7 +112,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

View File

@@ -4,10 +4,13 @@ package android
import (
"context"
"fmt"
"os"
"slices"
"sync"
"golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
@@ -16,10 +19,13 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile
@@ -53,7 +59,6 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
@@ -67,12 +72,11 @@ type Client struct {
}
// NewClient instantiate a new Client
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: cfgFile,
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -84,10 +88,16 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -107,23 +117,29 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.login(urlOpener)
err = auth.login(urlOpener, isAndroidTV)
if err != nil {
return err
}
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -141,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources
@@ -156,6 +172,19 @@ func (c *Client) Stop() {
c.ctxCancel()
}
func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
return e.RenewTun(fd)
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -177,6 +206,7 @@ func (c *Client) PeersList() *PeerInfoArray {
p.IP,
p.FQDN,
p.ConnStatus.String(),
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
}
peerInfos[n] = pi
}
@@ -201,31 +231,43 @@ func (c *Client) Networks() *NetworkArray {
return nil
}
routeSelector := routeManager.GetRouteSelector()
if routeSelector == nil {
log.Error("could not get route selector")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
resolvedDomains := c.recorder.GetResolvedDomainsStates()
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
r := routes[0]
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: netStr,
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
Name: string(id),
Network: netStr,
Peer: routePeer.FQDN,
Status: routePeer.ConnStatus.String(),
IsSelected: routeSelector.IsSelected(id),
Domains: domains,
}
networkArray.Add(network)
}
@@ -253,6 +295,69 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func (c *Client) toggleRoute(command routeCommand) error {
return command.toggleRoute()
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}
engine := client.Engine()
if engine == nil {
return nil, fmt.Errorf("engine is not running")
}
manager := engine.GetRouteManager()
if manager == nil {
return nil, fmt.Errorf("could not get route manager")
}
return manager, nil
}
func (c *Client) SelectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
}
func (c *Client) DeselectRoute(route string) error {
manager, err := c.getRouteManager()
if err != nil {
return err
}
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
}
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
// with its resolved IP addresses from the provided resolvedDomains map.
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
domains := NetworkDomains{}
for _, d := range route.Domains {
networkDomain := NetworkDomain{
Address: d.SafeString(),
}
if info, exists := resolvedDomains[d]; exists {
for _, prefix := range info.Prefixes {
networkDomain.addResolvedIP(prefix.Addr().String())
}
}
domains.Add(&networkDomain)
}
return domains
}
func exportEnvList(list *EnvList) {
if list == nil {
return

View File

@@ -32,7 +32,7 @@ type ErrListener interface {
// URLOpener it is a callback interface. The Open function will be triggered if
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
Open(url string, userCode string)
OnLoginSuccess()
}
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
}
// Login try register the client on the server
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
go func() {
err := a.login(urlOpener)
err := a.login(urlOpener, isAndroidTV)
if err != nil {
resultListener.OnError(err)
} else {
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
}()
}
func (a *Auth) login(urlOpener URLOpener) error {
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
// check if we need to generate JWT token
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
if err != nil {
return nil, err
}
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete)
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)

View File

@@ -0,0 +1,56 @@
//go:build android
package android
import "fmt"
type ResolvedIPs struct {
resolvedIPs []string
}
func (r *ResolvedIPs) Add(ipAddress string) {
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
}
func (r *ResolvedIPs) Get(i int) (string, error) {
if i < 0 || i >= len(r.resolvedIPs) {
return "", fmt.Errorf("%d is out of range", i)
}
return r.resolvedIPs[i], nil
}
func (r *ResolvedIPs) Size() int {
return len(r.resolvedIPs)
}
type NetworkDomain struct {
Address string
resolvedIPs ResolvedIPs
}
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
d.resolvedIPs.Add(resolvedIP)
}
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
return &d.resolvedIPs
}
type NetworkDomains struct {
domains []*NetworkDomain
}
func (n *NetworkDomains) Add(domain *NetworkDomain) {
n.domains = append(n.domains, domain)
}
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
if i < 0 || i >= len(n.domains) {
return nil, fmt.Errorf("%d is out of range", i)
}
return n.domains[i], nil
}
func (n *NetworkDomains) Size() int {
return len(n.domains)
}

View File

@@ -3,10 +3,16 @@
package android
type Network struct {
Name string
Network string
Peer string
Status string
Name string
Network string
Peer string
Status string
IsSelected bool
Domains NetworkDomains
}
func (n Network) GetNetworkDomains() *NetworkDomains {
return &n.Domains
}
type NetworkArray struct {

View File

@@ -1,3 +1,5 @@
//go:build android
package android
// PeerInfo describe information about the peers. It designed for the UI usage
@@ -5,6 +7,11 @@ type PeerInfo struct {
IP string
FQDN string
ConnStatus string // Todo replace to enum
Routes PeerRoutes
}
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
return &p.Routes
}
// PeerInfoArray is a wrapper of []PeerInfo

View File

@@ -0,0 +1,20 @@
//go:build android
package android
import "fmt"
type PeerRoutes struct {
routes []string
}
func (p *PeerRoutes) Get(i int) (string, error) {
if i < 0 || i >= len(p.routes) {
return "", fmt.Errorf("%d is out of range", i)
}
return p.routes[i], nil
}
func (p *PeerRoutes) Size() int {
return len(p.routes)
}

View File

@@ -0,0 +1,10 @@
//go:build android
package android
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
// at their default locations due to android OS restrictions.
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
}

View File

@@ -0,0 +1,257 @@
//go:build android
package android
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
// Android-specific config filename (different from desktop default.json)
defaultConfigFilename = "netbird.cfg"
// Subdirectory for non-default profiles (must match Java Preferences.java)
profilesSubdir = "profiles"
// Android uses a single user context per app (non-empty username required by ServiceManager)
androidUsername = "android"
)
// Profile represents a profile for gomobile
type Profile struct {
Name string
IsActive bool
}
// ProfileArray wraps profiles for gomobile compatibility
type ProfileArray struct {
items []*Profile
}
// Length returns the number of profiles
func (p *ProfileArray) Length() int {
return len(p.items)
}
// Get returns the profile at index i
func (p *ProfileArray) Get(i int) *Profile {
if i < 0 || i >= len(p.items) {
return nil
}
return p.items[i]
}
/*
/data/data/io.netbird.client/files/ ← configDir parameter
├── netbird.cfg ← Default profile config
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
└── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
// It wraps the internal profilemanager to provide Android-specific behavior
type ProfileManager struct {
configDir string
serviceMgr *profilemanager.ServiceManager
}
// NewProfileManager creates a new profile manager for Android
func NewProfileManager(configDir string) *ProfileManager {
// Set the default config path for Android (stored in root configDir, not profiles/)
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
// Set global paths for Android
profilemanager.DefaultConfigPathDir = configDir
profilemanager.DefaultConfigPath = defaultConfigPath
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
// Create ServiceManager with profiles/ subdirectory
// This avoids modifying the global ConfigDirOverride for profile listing
profilesDir := filepath.Join(configDir, profilesSubdir)
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
return &ProfileManager{
configDir: configDir,
serviceMgr: serviceMgr,
}
}
// ListProfiles returns all available profiles
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
// Convert internal profiles to Android Profile type
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
Name: p.Name,
IsActive: p.IsActive,
})
}
return &ProfileArray{items: profiles}, nil
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
config, err := profilemanager.ReadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to read profile config: %w", err)
}
// Clear authentication by removing private key and SSH key
config.PrivateKey = ""
config.SSHKey = ""
// Save config using internal profilemanager
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -0,0 +1,67 @@
//go:build android
package android
import (
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/route"
)
func executeRouteToggle(id string, manager routemanager.Manager,
operationName string,
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
netID := route.NetID(id)
routes := []route.NetID{netID}
log.Debugf("%s with id: %s", operationName, id)
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when %s: %s", operationName, err)
return fmt.Errorf("error %s: %w", operationName, err)
}
manager.TriggerSelection(manager.GetClientRoutes())
return nil
}
type routeCommand interface {
toggleRoute() error
}
type selectRouteCommand struct {
route string
manager routemanager.Manager
}
func (s selectRouteCommand) toggleRoute() error {
routeSelector := s.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
return routeSelector.SelectRoutes(routes, true, allRoutes)
}
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
}
type deselectRouteCommand struct {
route string
manager routemanager.Manager
}
func (d deselectRouteCommand) toggleRoute() error {
routeSelector := d.manager.GetRouteSelector()
if routeSelector == nil {
return fmt.Errorf("no route selector available")
}
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
}

View File

@@ -4,14 +4,12 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
if err != nil {
return nil, err
}
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := openBrowser(verificationURIComplete); err != nil {
if err := util.OpenBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -85,6 +85,9 @@ var (
// Execute executes the root command.
func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
}
return rootCmd.Execute()
}

View File

@@ -0,0 +1,176 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}

View File

@@ -0,0 +1,276 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

21
client/cmd/signer/main.go Normal file
View File

@@ -0,0 +1,21 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,220 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,74 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}

View File

@@ -51,6 +51,7 @@ var (
identityFile string
skipCachedToken bool
requestPTY bool
sshNoBrowser bool
)
var (
@@ -81,6 +82,7 @@ func init() {
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
@@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
return defaultValue
}
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() {
port = sshserver.DefaultSSHPort
@@ -196,6 +213,7 @@ func resetSSHGlobals() {
strictHostKeyChecking = true
knownHostsFile = ""
identityFile = ""
sshNoBrowser = false
}
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
@@ -370,6 +388,7 @@ type sshFlags struct {
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
NoBrowser bool
ConfigPath string
LogLevel string
LocalForwards []string
@@ -381,6 +400,7 @@ type sshFlags struct {
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
@@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
@@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
sshNoBrowser = flags.NoBrowser
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
@@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
NoBrowser: sshNoBrowser,
})
if err != nil {
@@ -749,7 +772,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
@@ -761,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
// where command-line flags cannot be passed. Default is to open browser.
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
@@ -788,7 +821,8 @@ var sshDetectCmd = &cobra.Command{
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(detectLogLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
@@ -797,15 +831,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
port, err := strconv.Atoi(portStr)
if err != nil {
log.Debugf("invalid port %q: %v", portStr, err)
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
log.Debugf("SSH server detection failed: %v", err)
cancel()
os.Exit(detection.ServerTypeRegular.ExitCode())
}
cancel()
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto"
@@ -24,8 +26,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)

13
client/cmd/update.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}

View File

@@ -0,0 +1,75 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}

View File

@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
client := internal.NewConnectClient(ctx, c.config, recorder, false)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available

View File

@@ -27,7 +27,11 @@ import (
)
const (
tableNat = "nat"
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("unable to list tables: %v", err)
return nil, fmt.Errorf("list tables: %w", err)
}
for _, table := range tables {
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
log.Errorf("failed to refresh rules: %s", err)
}
return nil
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
func (r *router) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
Exprs: exprsOut,
})
return nil
return r.conn.Flush()
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) acceptFilterTableRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
// iptables is not available but the filter table exists
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables()
return r.acceptFilterRulesNftables(r.filterTable)
}
return r.acceptFilterRulesIptables(ipt)
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *router) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
continue
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
Data: intf,
},
}
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
}
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return nil
}
func (r *router) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptFilterRulesNftables()
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
if chain.Table.Name != table.Name {
continue
}
@@ -1100,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *router) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *router) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
return fmt.Errorf("get rules: %v", err)
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
return chains
}
func (r *router) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
return nil
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
@@ -1128,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf(" unable to list rules: %v", err)
return fmt.Errorf("list rules: %w", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {

View File

@@ -3,6 +3,7 @@
package device
import (
"fmt"
"strings"
log "github.com/sirupsen/logrus"
@@ -19,11 +20,12 @@ import (
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
address wgaddr.Address
port int
key string
mtu uint16
iceBind *bind.ICEBind
// todo: review if we can eliminate the TunAdapter
tunAdapter TunAdapter
disableDNS bool
@@ -32,17 +34,19 @@ type WGTunDevice struct {
filteredDevice *FilteredDevice
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
renewableTun *RenewableTUN
}
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: iceBind,
tunAdapter: tunAdapter,
disableDNS: disableDNS,
renewableTun: NewRenewableTUN(),
}
}
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err
}
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to create Android interface: %s", err)
return nil, err
}
t.renewableTun.AddDevice(unmonitoredTUN)
t.name = name
t.filteredDevice = newDeviceFilter(tunDevice)
t.filteredDevice = newDeviceFilter(t.renewableTun)
log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) RenewTun(fd int) error {
if t.device == nil {
return fmt.Errorf("device not initialized")
}
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
if err != nil {
_ = unix.Close(fd)
log.Errorf("failed to renew Android interface: %s", err)
return err
}
t.renewableTun.AddDevice(unmonitoredTUN)
return nil
}
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement
return nil

View File

@@ -2,6 +2,13 @@
package device
import "fmt"
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
return t.create()
}
func (t *TunNetstackDevice) RenewTun(fd int) error {
// Doesn't make sense in Android for Netstack.
return fmt.Errorf("this function has not been implemented in Netstack for Android")
}

View File

@@ -0,0 +1,309 @@
//go:build android
package device
import (
"io"
"os"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
// closeAwareDevice wraps a tun.Device along with a flag
// indicating whether its Close method was called.
//
// It also redirects tun.Device's Events() to a separate goroutine
// and closes it when Close is called.
//
// The WaitGroup and CloseOnce fields are used to ensure that the
// goroutine is awaited and closed only once.
type closeAwareDevice struct {
isClosed atomic.Bool
tun.Device
closeEventCh chan struct{}
wg sync.WaitGroup
closeOnce sync.Once
}
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
return &closeAwareDevice{
Device: tunDevice,
isClosed: atomic.Bool{},
closeEventCh: make(chan struct{}),
}
}
// redirectEvents redirects the Events() method of the underlying tun.Device
// to the given channel (RenewableTUN's events channel).
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
c.wg.Add(1)
go func() {
defer c.wg.Done()
for {
select {
case ev, ok := <-c.Device.Events():
if !ok {
return
}
if ev == tun.EventDown {
continue
}
select {
case out <- ev:
case <-c.closeEventCh:
return
}
case <-c.closeEventCh:
return
}
}
}()
}
// Close calls the underlying Device's Close method
// after setting isClosed to true.
func (c *closeAwareDevice) Close() (err error) {
c.closeOnce.Do(func() {
c.isClosed.Store(true)
close(c.closeEventCh)
err = c.Device.Close()
c.wg.Wait()
})
return err
}
func (c *closeAwareDevice) IsClosed() bool {
return c.isClosed.Load()
}
type RenewableTUN struct {
devices []*closeAwareDevice
mu sync.Mutex
cond *sync.Cond
events chan tun.Event
closed atomic.Bool
}
func NewRenewableTUN() *RenewableTUN {
r := &RenewableTUN{
devices: make([]*closeAwareDevice, 0),
mu: sync.Mutex{},
events: make(chan tun.Event, 16),
}
r.cond = sync.NewCond(&r.mu)
return r
}
func (r *RenewableTUN) File() *os.File {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return nil
}
continue
}
file := dev.File()
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return file
}
}
// Read reads from an underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries reading from the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
for {
dev := r.peekLast()
if dev == nil {
// wait until AddDevice() signals a new device via cond.Broadcast()
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
return 0, io.EOF
}
continue
}
n, err = dev.Read(bufs, sizes, offset)
if err == nil {
return n, nil
}
// swap in progress; retry on the newest instead of returning the error
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err // propagate non-swap error
}
}
// Write writes to underlying tun.Device kept in the r.devices slice.
// If no device is available, it waits for one to be added via AddDevice().
//
// On error, it retries writing to the newest device instead of returning the error
// if the device is closed; if not, it propagates the error.
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
n, err := dev.Write(bufs, offset)
if err == nil {
return n, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return n, err
}
}
func (r *RenewableTUN) MTU() (int, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return 0, io.EOF
}
continue
}
mtu, err := dev.MTU()
if err == nil {
return mtu, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return 0, err
}
}
func (r *RenewableTUN) Name() (string, error) {
for {
dev := r.peekLast()
if dev == nil {
if !r.waitForDevice() {
return "", io.EOF
}
continue
}
name, err := dev.Name()
if err == nil {
return name, nil
}
if dev.IsClosed() {
time.Sleep(1 * time.Millisecond)
continue
}
return "", err
}
}
// Events returns a channel that is fed events from the underlying tun.Device's events channel
// once it is added.
func (r *RenewableTUN) Events() <-chan tun.Event {
return r.events
}
func (r *RenewableTUN) Close() error {
// Attempts to set the RenewableTUN closed flag to true.
// If it's already true, returns immediately.
if !r.closed.CompareAndSwap(false, true) {
return nil // already closed: idempotent
}
r.mu.Lock()
devices := r.devices
r.devices = nil
r.cond.Broadcast()
r.mu.Unlock()
var lastErr error
log.Debugf("closing %d devices", len(devices))
for _, device := range devices {
if err := device.Close(); err != nil {
log.Debugf("error closing a device: %v", err)
lastErr = err
}
}
close(r.events)
return lastErr
}
func (r *RenewableTUN) BatchSize() int {
return 1
}
func (r *RenewableTUN) AddDevice(device tun.Device) {
r.mu.Lock()
if r.closed.Load() {
r.mu.Unlock()
_ = device.Close()
return
}
var toClose *closeAwareDevice
if len(r.devices) > 0 {
toClose = r.devices[len(r.devices)-1]
}
cad := newClosableDevice(device)
cad.redirectEvents(r.events)
r.devices = []*closeAwareDevice{cad}
r.cond.Broadcast()
r.mu.Unlock()
if toClose != nil {
if err := toClose.Close(); err != nil {
log.Debugf("error closing last device: %v", err)
}
}
}
func (r *RenewableTUN) waitForDevice() bool {
r.mu.Lock()
defer r.mu.Unlock()
for len(r.devices) == 0 && !r.closed.Load() {
r.cond.Wait()
}
return !r.closed.Load()
}
func (r *RenewableTUN) peekLast() *closeAwareDevice {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.devices) == 0 {
return nil
}
return r.devices[len(r.devices)-1]
}

View File

@@ -21,5 +21,6 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
RenewTun(fd int) error
GetICEBind() device.EndpointManager
}

View File

@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on non-android")
}

View File

@@ -6,6 +6,7 @@ import (
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// todo: review does this function really necessary or can we merge it with iOS
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}
func (w *WGIface) RenewTun(fd int) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.tun.RenewTun(fd)
}

View File

@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
func (w *WGIface) RenewTun(fd int) error {
return fmt.Errorf("this function has not been implemented on this platform")
}

View File

@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
return t.AccessToken
}
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
}
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
//
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
return authenticateWithDeviceCodeFlow(ctx, config, hint)
}

View File

@@ -13,6 +13,7 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
excludedRanges := getSystemExcludedPortRanges()
for _, redirectURL := range config.RedirectURLs {
if !isRedirectURLPortUsed(redirectURL) {
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
availableRedirectURL = redirectURL
break
}
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
}
if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
switch p.providerConfig.LoginFlag {
case common.LoginFlagPromptLogin:
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
case common.LoginFlagMaxAge0:
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
@@ -282,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
return base64.RawURLEncoding.EncodeToString(sha2[:])
}
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
func isRedirectURLPortUsed(redirectURL string) bool {
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
parsedURL, err := url.Parse(redirectURL)
if err != nil {
log.Errorf("failed to parse redirect URL: %v", err)
return true
}
addr := fmt.Sprintf(":%s", parsedURL.Port())
port := parsedURL.Port()
if isPortInExcludedRange(port, excludedRanges) {
log.Warnf("port %s is in Windows excluded port range, skipping", port)
return true
}
addr := fmt.Sprintf(":%s", port)
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false
@@ -304,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
return true
}
// excludedPortRange represents a range of excluded ports.
type excludedPortRange struct {
start int
end int
}
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
if len(excludedRanges) == 0 {
return false
}
portNum, err := strconv.Atoi(port)
if err != nil {
log.Debugf("invalid port number %s: %v", port, err)
return false
}
for _, r := range excludedRanges {
if portNum >= r.start && portNum <= r.end {
return true
}
}
return false
}
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
if err != nil {

View File

@@ -0,0 +1,8 @@
//go:build !windows
package auth
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
func getSystemExcludedPortRanges() []excludedPortRange {
return nil
}

View File

@@ -2,8 +2,11 @@ package auth
import (
"context"
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
expectContains []string
}{
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
name: "Prompt login",
loginFlag: mgm.LoginFlagPromptLogin,
expectContains: []string{promptLogin},
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
name: "Max age 0",
loginFlag: mgm.LoginFlagMaxAge0,
expectContains: []string{maxAge0},
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
loginFlag: mgm.LoginFlagPromptLogin,
disablePromptLogin: true,
expectContains: []string{},
},
{
name: "None flag should not add parameters",
loginFlag: mgm.LoginFlagNone,
expectContains: []string{},
},
}
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
LoginFlag: tc.loginFlag,
DisablePromptLogin: tc.disablePromptLogin,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
t.Fatalf("Failed to request auth info: %v", err)
}
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
for _, expected := range tc.expectContains {
require.Contains(t, authInfo.VerificationURIComplete, expected)
}
})
}
}
func TestIsPortInExcludedRange(t *testing.T) {
tests := []struct {
name string
port string
excludedRanges []excludedPortRange
expectedBlocked bool
}{
{
name: "Port in excluded range",
port: "8080",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: true,
},
{
name: "Port at start of range",
port: "8000",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: true,
},
{
name: "Port at end of range",
port: "8100",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: true,
},
{
name: "Port before range",
port: "7999",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
{
name: "Port after range",
port: "8101",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
{
name: "Empty excluded ranges",
port: "8080",
excludedRanges: []excludedPortRange{},
expectedBlocked: false,
},
{
name: "Nil excluded ranges",
port: "8080",
excludedRanges: nil,
expectedBlocked: false,
},
{
name: "Multiple ranges - port in second range",
port: "9050",
excludedRanges: []excludedPortRange{
{start: 8000, end: 8100},
{start: 9000, end: 9100},
},
expectedBlocked: true,
},
{
name: "Multiple ranges - port not in any range",
port: "8500",
excludedRanges: []excludedPortRange{
{start: 8000, end: 8100},
{start: 9000, end: 9100},
},
expectedBlocked: false,
},
{
name: "Invalid port string",
port: "invalid",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
{
name: "Empty port string",
port: "",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedBlocked: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
})
}
}
func TestIsRedirectURLPortUsed(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
_ = listener.Close()
}()
usedPort := listener.Addr().(*net.TCPAddr).Port
tests := []struct {
name string
redirectURL string
excludedRanges []excludedPortRange
expectedUsed bool
}{
{
name: "Port in excluded range",
redirectURL: "http://127.0.0.1:8080/",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedUsed: true,
},
{
name: "Port actually in use",
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
excludedRanges: nil,
expectedUsed: true,
},
{
name: "Port not in use and not excluded",
redirectURL: "http://127.0.0.1:65432/",
excludedRanges: nil,
expectedUsed: false,
},
{
name: "Invalid URL without port",
redirectURL: "not-a-valid-url",
excludedRanges: nil,
expectedUsed: false,
},
{
name: "Port excluded even if not in use",
redirectURL: "http://127.0.0.1:8050/",
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
expectedUsed: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
})
}
}

View File

@@ -0,0 +1,86 @@
//go:build windows
package auth
import (
"bufio"
"fmt"
"os/exec"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
func getSystemExcludedPortRanges() []excludedPortRange {
ranges, err := getExcludedPortRangesFromNetsh()
if err != nil {
log.Debugf("failed to get Windows excluded port ranges: %v", err)
return nil
}
return ranges
}
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("netsh command: %w", err)
}
return parseExcludedPortRanges(string(output))
}
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
var ranges []excludedPortRange
scanner := bufio.NewScanner(strings.NewReader(output))
foundHeader := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
foundHeader = true
continue
}
if !foundHeader {
continue
}
if strings.Contains(line, "----------") {
continue
}
if line == "" {
continue
}
fields := strings.Fields(line)
if len(fields) < 2 {
continue
}
startPort, err := strconv.Atoi(fields[0])
if err != nil {
continue
}
endPort, err := strconv.Atoi(fields[1])
if err != nil {
continue
}
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("scan output: %w", err)
}
return ranges, nil
}

View File

@@ -0,0 +1,116 @@
//go:build windows
package auth
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
tests := []struct {
name string
netshOutput string
expectedRanges []excludedPortRange
expectError bool
}{
{
name: "Valid netsh output with multiple ranges",
netshOutput: `
Protocol tcp Dynamic Port Range
---------------------------------
Start Port : 49152
Number of Ports : 16384
Protocol tcp Excluded Port Ranges
---------------------------------
Start Port End Port
---------- --------
5357 5357 *
50000 50059 *
`,
expectedRanges: []excludedPortRange{
{start: 5357, end: 5357},
{start: 50000, end: 50059},
},
expectError: false,
},
{
name: "Empty output",
netshOutput: `
Protocol tcp Dynamic Port Range
---------------------------------
Start Port : 49152
Number of Ports : 16384
`,
expectedRanges: nil,
expectError: false,
},
{
name: "Single range",
netshOutput: `
Protocol tcp Excluded Port Ranges
---------------------------------
Start Port End Port
---------- --------
8080 8090
`,
expectedRanges: []excludedPortRange{
{start: 8080, end: 8090},
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ranges, err := parseExcludedPortRanges(tt.netshOutput)
if tt.expectError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedRanges, ranges)
}
})
}
}
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
ranges := getSystemExcludedPortRanges()
t.Logf("Found %d excluded port ranges on this system", len(ranges))
listener1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
_ = listener1.Close()
}()
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
availablePort := 65432
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
},
UseIDToken: true,
}
flow, err := NewPKCEAuthorizationFlow(config)
require.NoError(t, err)
require.NotNil(t, flow)
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
"Should skip port in use and select available port")
}

View File

@@ -24,10 +24,14 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -39,11 +43,13 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -52,13 +58,15 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
}
}
@@ -74,6 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -82,6 +91,7 @@ func (c *ConnectClient) RunOnAndroid(
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil)
}
@@ -160,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
var path string
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
// On mobile, use the provided state file path directly
if !fileExists(mobileDependency.StateFilePath) {
if err := createFile(mobileDependency.StateFilePath); err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDependency.StateFilePath
} else {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -271,15 +308,25 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
@@ -291,12 +338,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
<-engineCtx.Done()
c.engineMutex.Lock()
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
// todo: consider to remove this condition. Is not thread safe.
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -56,6 +57,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -109,6 +111,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -327,6 +332,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -354,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
log.Errorf("failed to add updater logs: %v", err)
}
return nil
}
@@ -522,6 +535,18 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
@@ -630,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
func (g *BundleGenerator) addUpdateLogs() error {
inst := installer.New()
logFiles := inst.LogFiles()
if len(logFiles) == 0 {
return nil
}
log.Infof("adding updater logs")
for _, logFile := range logFiles {
data, err := os.ReadFile(logFile)
if err != nil {
log.Warnf("failed to read update log file %s: %v", logFile, err)
continue
}
baseName := filepath.Base(logFile)
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
}
}
return nil
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()

View File

@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
continue
}
for _, record := range zone.Records {
if record.Type != int(dns.TypeA) {
continue
@@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{
Domain: zoneName,
Records: records,
Domain: zoneName,
Records: records,
SearchDomainDisabled: true,
}
config.CustomZones = append(config.CustomZones, reverseZone)

View File

@@ -11,11 +11,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
)
const (
ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa."
)
type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
}
for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
MatchOnly: matchOnly,
MatchOnly: customZone.SearchDomainDisabled,
})
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
@@ -26,6 +27,11 @@ type Resolver struct {
mutex sync.RWMutex
}
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
return err
}
var aRecords, aaaaRecords []dns.RR
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {

View File

@@ -80,6 +80,7 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -207,6 +208,7 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
// Fall through to apply config anyway (fail-safe approach)
} else if s.currentConfigHash == hash {
log.Debugf("not applying host config as there are no changes")
return
}
log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
return
}
// Only update hash if it was computed successfully and config was applied
if err == nil {
s.currentConfigHash = hash
}
s.registerFallback(config)

View File

@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 4,
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
// the domain remains in the config (ref count goes from 2 to 1), so the host
// config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 3,
},
{
name: "Config update with new domains after registration",
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 3,
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
// it's removed from extraDomains but still remains in the config (from customZones),
// so the host config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 2,
},
{
name: "Register domain that is part of nameserver group",

View File

@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
logger.Warn(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {

View File

@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)

View File

@@ -42,14 +42,13 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -73,6 +72,7 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -201,6 +201,9 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -221,17 +224,7 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -247,28 +240,12 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -280,7 +257,6 @@ func (e *Engine) Stop() error {
return nil
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.connMgr != nil {
e.connMgr.Close()
@@ -298,9 +274,6 @@ func (e *Engine) Stop() error {
e.cleanupSSHConfig()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err)
@@ -308,24 +281,33 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil
}
e.stopDNSForwarder()
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
if err := e.removeAllPeers(); err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
log.Errorf("failed to remove all peers: %s", err)
}
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
e.stopDNSForwarder()
// stop/restore DNS after peers are closed but before interface goes down
// so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
@@ -337,16 +319,18 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
defer stateCancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
if err := e.stateManager.Stop(stateCtx); err != nil {
log.Errorf("failed to stop state manager: %v", err)
}
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
@@ -432,8 +416,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err != nil {
return fmt.Errorf("create rosenpass manager: %w", err)
}
err := e.rpManager.Run()
if err != nil {
if err := e.rpManager.Run(); err != nil {
return fmt.Errorf("run rosenpass manager: %w", err)
}
}
@@ -485,6 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
}
@@ -538,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -746,10 +737,54 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -1207,7 +1242,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
for _, zone := range protoDNSConfig.GetCustomZones() {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{
@@ -1367,6 +1404,11 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
@@ -1831,6 +1873,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
return e.wgInterface.Address().IP
}
func (e *Engine) RenewTun(fd int) error {
e.syncMsgMux.Lock()
wgInterface := e.wgInterface
e.syncMsgMux.Unlock()
if wgInterface == nil {
return fmt.Errorf("wireguard interface not initialized")
}
return wgInterface.RenewTun(fd)
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(
enabled bool,

View File

@@ -30,11 +30,12 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -54,7 +55,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -110,6 +110,10 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time
}
func (m *MockWGIface) RenewTun(_ int) error {
return nil
}
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
@@ -249,6 +253,7 @@ func TestEngine_SSH(t *testing.T) {
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -410,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -643,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -808,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1010,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1536,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}
@@ -1624,14 +1621,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -20,6 +20,7 @@ import (
type wgIfaceBase interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
RenewTun(fd int) error
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
// mu protects cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
@@ -165,19 +166,26 @@ func getConfigDir() (string, error) {
if ConfigDirOverride != "" {
return ConfigDirOverride, nil
}
configDir, err := os.UserConfigDir()
base, err := baseConfigDir()
if err != nil {
return "", err
}
configDir = filepath.Join(configDir, "netbird")
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
configDir := filepath.Join(base, "netbird")
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
}
func baseConfigDir() (string, error) {
if runtime.GOOS == "darwin" {
if u, err := user.Current(); err == nil && u.HomeDir != "" {
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
}
}
return configDir, nil
return os.UserConfigDir()
}
func getConfigDirForUser(username string) (string, error) {

View File

@@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) {
}
type ServiceManager struct {
profilesDir string // If set, overrides ConfigDirOverride for profile operations
}
func NewServiceManager(defaultConfigPath string) *ServiceManager {
@@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
return &ServiceManager{}
}
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
// This allows setting the profiles directory without modifying the global ConfigDirOverride
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
if defaultConfigPath != "" {
DefaultConfigPath = defaultConfigPath
}
return &ServiceManager{
profilesDir: profilesDir,
}
}
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
@@ -240,7 +252,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
configDir, err := getConfigDirForUser(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -270,7 +282,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := getConfigDirForUser(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -302,7 +314,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
}
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := getConfigDirForUser(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
@@ -361,7 +373,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
configDir, err := getConfigDirForUser(activeProf.Username)
configDir, err := s.getConfigDir(activeProf.Username)
if err != nil {
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
return defaultStatePath
@@ -369,3 +381,12 @@ func (s *ServiceManager) GetStatePath() string {
return filepath.Join(configDir, activeProf.Name+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
func (s *ServiceManager) getConfigDir(username string) (string, error) {
if s.profilesDir != "" {
return s.profilesDir, nil
}
return getConfigDirForUser(username)
}

View File

@@ -0,0 +1,218 @@
//go:build darwin && !ios
package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import (
"context"
"fmt"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
)
//export sleepCallbackBridge
func sleepCallbackBridge() {
log.Info("sleepCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeSleep)
}
}
//export resumedCallbackBridge
func resumedCallbackBridge() {
log.Info("resumedCallbackBridge event triggered")
}
//export suspendedCallbackBridge
func suspendedCallbackBridge() {
log.Info("suspendedCallbackBridge event triggered")
}
//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
log.Info("poweredOnCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeWakeUp)
}
}
type Detector struct {
callback func(event EventType)
ctx context.Context
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
}
func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
if _, exists := serviceRegistry[d]; exists {
return fmt.Errorf("detector service already registered")
}
d.callback = callback
d.ctx, d.cancel = context.WithCancel(context.Background())
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
return nil
}
serviceRegistry[d] = struct{}{}
// CFRunLoop must run on a single fixed OS thread
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.registerNotifications()
}()
log.Info("sleep detection service started on macOS")
return nil
}
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// and the runloop is stopped and cleaned up.
func (d *Detector) Deregister() error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
_, exists := serviceRegistry[d]
if !exists {
return nil
}
// cancel and remove this detector
d.cancel()
delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 {
return nil
}
log.Info("sleep detection service stopping (deregister)")
// Deregister IOKit notifications, stop runloop, and free resources
C.unregisterNotifications()
return nil
}
func (d *Detector) triggerCallback(event EventType) {
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
cb := d.callback
go func(callback func(event EventType)) {
log.Info("sleep detection event fired")
callback(event)
close(doneChan)
}(cb)
select {
case <-doneChan:
case <-d.ctx.Done():
case <-timeout.C:
log.Warnf("sleep callback timed out")
}
}

View File

@@ -0,0 +1,9 @@
//go:build !darwin || ios
package sleep
import "fmt"
func NewDetector() (detector, error) {
return nil, fmt.Errorf("sleep not supported on this platform")
}

View File

@@ -0,0 +1,37 @@
package sleep
var (
EventTypeUnknown EventType = 0
EventTypeSleep EventType = 1
EventTypeWakeUp EventType = 2
)
type EventType int
type detector interface {
Register(callback func(eventType EventType)) error
Deregister() error
}
type Service struct {
detector detector
}
func New() (*Service, error) {
d, err := NewDetector()
if err != nil {
return nil, err
}
return &Service{
detector: d,
}, nil
}
func (s *Service) Register(callback func(eventType EventType)) error {
return s.detector.Register(callback)
}
func (s *Service) Deregister() error {
return s.detector.Deregister()
}

View File

@@ -0,0 +1,35 @@
// Package updatemanager provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
// # Overview
//
// The update manager operates as a background service that continuously monitors for
// available updates and automatically initiates the update process when conditions are met.
// It integrates with the installer package to perform the actual installation.
//
// # Update Flow
//
// The complete update process follows these steps:
//
// 1. Manager receives update directive via SetVersion() or detects new version
// 2. Manager validates update should proceed (version comparison, rate limiting)
// 3. Manager publishes "updating" event to status recorder
// 4. Manager persists UpdateState to track update attempt
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
// 6. Manager triggers installation via installer.RunInstallation()
// 7. Installer package handles the actual installation process
// 8. On next startup, CheckUpdateSuccess() verifies update completion
// 9. Manager publishes success/failure event to status recorder
// 10. Manager cleans up UpdateState
//
// # State Management
//
// Update state is persisted across restarts to track update attempts:
//
// - PreUpdateVersion: Version before update attempt
// - TargetVersion: Version attempting to update to
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updatemanager

View File

@@ -0,0 +1,138 @@
package downloader
import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
const (
userAgent = "NetBird agent installer/%s"
DefaultRetryDelay = 3 * time.Second
)
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
log.Debugf("starting download from %s", url)
out, err := os.Create(dstFile)
if err != nil {
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
}
defer func() {
if cerr := out.Close(); cerr != nil {
log.Warnf("error closing file %q: %v", dstFile, cerr)
}
}()
// First attempt
err = downloadToFileOnce(ctx, url, out)
if err == nil {
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
// If retryDelay is 0, don't retry
if retryDelay == 0 {
return err
}
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
// Sleep before retry
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
}
// Truncate file before retry
if err := out.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate file on retry: %w", err)
}
if _, err := out.Seek(0, 0); err != nil {
return fmt.Errorf("failed to seek to beginning of file: %w", err)
}
// Second attempt
if err := downloadToFileOnce(ctx, url, out); err != nil {
return fmt.Errorf("download failed after retry: %w", err)
}
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return data, nil
}
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("failed to write response body to file: %w", err)
}
return nil
}
func sleepWithContext(ctx context.Context, duration time.Duration) error {
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -0,0 +1,199 @@
package downloader
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
const (
retryDelay = 100 * time.Millisecond
)
func TestDownloadToFile_Success(t *testing.T) {
// Create a test server that responds successfully
content := "test file content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
}
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
content := "test file content after retry"
var attemptCount atomic.Int32
// Create a test server that fails on first attempt, succeeds on second
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := attemptCount.Add(1)
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should succeed after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
t.Fatalf("expected no error after retry, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
// Verify it took 2 attempts
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should fail after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
t.Fatal("expected error after retry, got nil")
}
// Verify it tried 2 times
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Create a context that will be cancelled during retry delay
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay (during the retry sleep)
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
// Download the file (should fail due to context cancellation during retry)
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
if err == nil {
t.Fatal("expected error due to context cancellation, got nil")
}
// Should have only made 1 attempt (cancelled during retry delay)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_InvalidURL(t *testing.T) {
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
if err == nil {
t.Fatal("expected error for invalid URL, got nil")
}
}
func TestDownloadToFile_InvalidDestination(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test"))
}))
defer server.Close()
// Use an invalid destination path
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
if err == nil {
t.Fatal("expected error for invalid destination, got nil")
}
}
func TestDownloadToFile_NoRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file with retryDelay = 0 (should not retry)
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
t.Fatal("expected error, got nil")
}
// Verify it only made 1 attempt (no retry)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package installer
func UpdaterBinaryNameWithoutExtension() string {
return updaterBinary
}

View File

@@ -0,0 +1,11 @@
package installer
import (
"path/filepath"
"strings"
)
func UpdaterBinaryNameWithoutExtension() string {
ext := filepath.Ext(updaterBinary)
return strings.TrimSuffix(updaterBinary, ext)
}

View File

@@ -0,0 +1,111 @@
// Package installer provides functionality for managing NetBird application
// updates and installations across Windows, macOS. It handles
// the complete update lifecycle including artifact download, cryptographic verification,
// installation execution, process management, and result reporting.
//
// # Architecture
//
// The installer package uses a two-process architecture to enable self-updates:
//
// 1. Service Process: The main NetBird daemon process that initiates updates
// 2. Updater Process: A detached child process that performs the actual installation
//
// This separation is critical because:
// - The service binary cannot update itself while running
// - The installer (EXE/MSI/PKG) will terminate the service during installation
// - The updater process survives service termination and restarts it after installation
// - Results can be communicated back to the service after it restarts
//
// # Update Flow
//
// Service Process (RunInstallation):
//
// 1. Validates target version format (semver)
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
// 3. Downloads installer file from GitHub releases (if applicable)
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
// launching updater)
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
// 6. Launches updater process with detached mode:
// - --temp-dir: Temporary directory path
// - --service-dir: Service installation directory
// - --installer-file: Path to downloaded installer (if applicable)
// - --dry-run: Optional flag to test without actually installing
// 7. Service process continues running (will be terminated by installer later)
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
//
// Updater Process (Setup):
//
// 1. Receives parameters from service via command-line arguments
// 2. Runs installer with appropriate silent/quiet flags:
// - Windows EXE: installer.exe /S
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
// - macOS PKG: installer -pkg installer.pkg -target /
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
// 3. Installer terminates daemon and UI processes
// 4. Installer replaces binaries with new version
// 5. Updater waits for installer to complete
// 6. Updater restarts daemon:
// - Windows: netbird.exe service start
// - macOS/Linux: netbird service start
// 7. Updater restarts UI:
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
// - Linux: Not implemented (UI typically auto-starts)
// 8. Updater writes result.json with success/error status
// 9. Updater process exits
//
// # Result Communication
//
// The ResultHandler (result.go) manages communication between updater and service:
//
// Result Structure:
//
// type Result struct {
// Success bool // true if installation succeeded
// Error string // error message if Success is false
// ExecutedAt time.Time // when installation completed
// }
//
// Result files are automatically cleaned up after being read.
//
// # File Locations
//
// Temporary Directory (platform-specific):
//
// Windows:
// - Path: %ProgramData%\Netbird\tmp-install
// - Example: C:\ProgramData\Netbird\tmp-install
//
// macOS:
// - Path: /var/lib/netbird/tmp-install
// - Requires root permissions
//
// Files created during installation:
//
// tmp-install/
// installer.log
// updater[.exe] # Copy of service binary
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
// result.json # Installation result
// msi.log # MSI verbose log (Windows MSI only)
//
// # API Reference
//
// # Cleanup
//
// CleanUpInstallerFiles() removes temporary files after successful installation:
// - Downloaded installer files (*.exe, *.msi, *.pkg)
// - Updater binary copy
// - Does NOT remove result.json (cleaned by ResultHandler after read)
// - Does NOT remove msi.log (kept for debugging)
//
// # Dry-Run Mode
//
// Dry-run mode allows testing the update process without actually installing:
//
// Enable via environment variable:
//
// export NB_AUTO_UPDATE_DRY_RUN=true
// netbird service install-update 0.29.0
package installer

View File

@@ -0,0 +1,50 @@
//go:build !windows && !darwin
package installer
import (
"context"
"fmt"
)
const (
updaterBinary = "updater"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
func (u *Installer) TempDir() string {
return ""
}
func (c *Installer) LogFiles() []string {
return []string{}
}
func (u *Installer) CleanUpInstallerFiles() error {
return nil
}
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
return fmt.Errorf("unsupported platform")
}
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
return fmt.Errorf("unsupported platform")
}

View File

@@ -0,0 +1,293 @@
//go:build windows || darwin
package installer
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"github.com/hashicorp/go-multierror"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{
tempDir: defaultTempDir,
}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
// RunInstallation starts the updater process to run the installation
// This will run by the original service process
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
resultHandler := NewResultHandler(u.tempDir)
defer func() {
if err != nil {
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
log.Errorf("failed to write error result: %v", writeErr)
}
}
}()
if err := validateTargetVersion(targetVersion); err != nil {
return err
}
if err := u.mkTempDir(); err != nil {
return err
}
var installerFile string
// Download files only when not using any third-party store
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
log.Infof("download installer")
var err error
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
if err != nil {
log.Errorf("failed to download installer: %v", err)
return err
}
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
if err != nil {
log.Errorf("failed to create artifact verify: %v", err)
return err
}
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
log.Errorf("artifact verification error: %v", err)
return err
}
}
log.Infof("running installer")
updaterPath, err := u.copyUpdater()
if err != nil {
return err
}
// the directory where the service has been installed
workspace, err := getServiceDir()
if err != nil {
return err
}
args := []string{
"--temp-dir", u.tempDir,
"--service-dir", workspace,
}
if isDryRunEnabled() {
args = append(args, "--dry-run=true")
}
if installerFile != "" {
args = append(args, "--installer-file", installerFile)
}
updateCmd := exec.Command(updaterPath, args...)
log.Infof("starting updater process: %s", updateCmd.String())
// Configure the updater to run in a separate session/process group
// so it survives the parent daemon being stopped
setUpdaterProcAttr(updateCmd)
// Start the updater process asynchronously
if err := updateCmd.Start(); err != nil {
return err
}
pid := updateCmd.Process.Pid
log.Infof("updater started with PID %d", pid)
// Release the process so the OS can fully detach it
if err := updateCmd.Process.Release(); err != nil {
log.Warnf("failed to release updater process: %v", err)
}
return nil
}
// CleanUpInstallerFiles
// - the installer file (pkg, exe, msi)
// - the selfcopy updater.exe
func (u *Installer) CleanUpInstallerFiles() error {
// Check if tempDir exists
info, err := os.Stat(u.tempDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
var merr *multierror.Error
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
}
entries, err := os.ReadDir(u.tempDir)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
for _, ext := range binaryExtensions {
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
}
break
}
}
}
return merr.ErrorOrNil()
}
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
fileURL := urlWithVersionArch(installerType, targetVersion)
// Clean up temp directory on error
var success bool
defer func() {
if !success {
if err := os.RemoveAll(u.tempDir); err != nil {
log.Errorf("error cleaning up temporary directory: %v", err)
}
}
}()
fileName := path.Base(fileURL)
if fileName == "." || fileName == "/" || fileName == "" {
return "", fmt.Errorf("invalid file URL: %s", fileURL)
}
outputFilePath := filepath.Join(u.tempDir, fileName)
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
return "", err
}
success = true
return outputFilePath, nil
}
func (u *Installer) TempDir() string {
return u.tempDir
}
func (u *Installer) mkTempDir() error {
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
log.Debugf("failed to create tempdir: %s", u.tempDir)
return err
}
return nil
}
func (u *Installer) copyUpdater() (string, error) {
src, err := getServiceBinary()
if err != nil {
return "", fmt.Errorf("failed to get updater binary: %w", err)
}
dst := filepath.Join(u.tempDir, updaterBinary)
if err := copyFile(src, dst); err != nil {
return "", fmt.Errorf("failed to copy updater binary: %w", err)
}
if err := os.Chmod(dst, 0o755); err != nil {
return "", fmt.Errorf("failed to set permissions: %w", err)
}
return dst, nil
}
func validateTargetVersion(targetVersion string) error {
if targetVersion == "" {
return fmt.Errorf("target version cannot be empty")
}
_, err := goversion.NewVersion(targetVersion)
if err != nil {
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
}
return nil
}
func copyFile(src, dst string) error {
log.Infof("copying %s to %s", src, dst)
in, err := os.Open(src)
if err != nil {
return fmt.Errorf("open source: %w", err)
}
defer func() {
if err := in.Close(); err != nil {
log.Warnf("failed to close source file: %v", err)
}
}()
out, err := os.Create(dst)
if err != nil {
return fmt.Errorf("create destination: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Warnf("failed to close destination file: %v", err)
}
}()
if _, err := io.Copy(out, in); err != nil {
return fmt.Errorf("copy: %w", err)
}
return nil
}
func getServiceDir() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", err
}
return filepath.Dir(exePath), nil
}
func getServiceBinary() (string, error) {
return os.Executable()
}
func isDryRunEnabled() bool {
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
}

View File

@@ -0,0 +1,11 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -0,0 +1,12 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, msiLogFile),
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -0,0 +1,238 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const (
daemonName = "netbird"
updaterBinary = "updater"
uiBinary = "/Applications/NetBird.app"
defaultTempDir = "/var/lib/netbird/tmp-install"
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
var (
binaryExtensions = []string{"pkg"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
// skip service restart if dry-run mode is enabled
if dryRun {
return
}
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(); err != nil {
log.Errorf("failed to start UI: %v", err)
}
}()
if dryRun {
time.Sleep(7 * time.Second)
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
switch TypeOfInstaller(ctx) {
case TypePKG:
resultErr = u.installPkgFile(ctx, installerFile)
case TypeHomebrew:
resultErr = u.updateHomeBrew(ctx)
}
return resultErr
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser() error {
log.Infof("starting netbird-ui: %s", uiBinary)
// Get the current console user
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("failed to get console user: %w", err)
}
username := strings.TrimSpace(string(output))
if username == "" || username == "root" {
return fmt.Errorf("no active user session found")
}
log.Infof("starting UI for user: %s", username)
// Get user's UID
userInfo, err := user.Lookup(username)
if err != nil {
return fmt.Errorf("failed to lookup user %s: %w", username, err)
}
// Start the UI process as the console user using launchctl
// This ensures the app runs in the user's context with proper GUI access
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
log.Infof("launchCmd: %s", launchCmd.String())
// Set the user's home directory for proper macOS app behavior
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
if err := launchCmd.Start(); err != nil {
return fmt.Errorf("failed to start UI process: %w", err)
}
// Release the process so it can run independently
if err := launchCmd.Process.Release(); err != nil {
log.Warnf("failed to release UI process: %v", err)
}
log.Infof("netbird-ui started successfully for user %s", username)
return nil
}
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
log.Infof("installing pkg file: %s", path)
// Kill any existing UI processes before installation
// This ensures the postinstall script's "open $APP" will start the new version
u.killUI()
volume := "/"
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
if err := cmd.Start(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if err := cmd.Wait(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("pkg file installed successfully")
return nil
}
func (u *Installer) updateHomeBrew(ctx context.Context) error {
log.Infof("updating homebrew")
// Kill any existing UI processes before upgrade
// This ensures the new version will be started after upgrade
u.killUI()
// Homebrew must be run as a non-root user
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
// Check both Apple Silicon and Intel Mac paths
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath := "/opt/homebrew/bin/brew"
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
// Try Intel Mac path
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath = "/usr/local/bin/brew"
}
fileInfo, err := os.Stat(brewTapPath)
if err != nil {
return fmt.Errorf("error getting homebrew installation path info: %w", err)
}
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
}
// Get username from UID
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
if err != nil {
return fmt.Errorf("error looking up brew installer user: %w", err)
}
userName := brewUser.Username
// Get user HOME, required for brew to run correctly
// https://github.com/Homebrew/brew/issues/15833
homeDir := brewUser.HomeDir
// Check if netbird-ui is installed (must run as the brew user, not root)
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
uiInstalled := checkUICmd.Run() == nil
// Homebrew does not support installing specific versions
// Thus it will always update to latest and ignore targetVersion
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
if uiInstalled {
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
}
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
cmd.Env = append(os.Environ(), "HOME="+homeDir)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
}
log.Infof("homebrew updated successfully")
return nil
}
func (u *Installer) killUI() {
log.Infof("killing existing netbird-ui processes")
cmd := exec.Command("pkill", "-x", "netbird-ui")
if output, err := cmd.CombinedOutput(); err != nil {
// pkill returns exit code 1 if no processes matched, which is fine
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
} else {
log.Infof("netbird-ui processes killed")
}
}
func urlWithVersionArch(_ Type, version string) string {
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -0,0 +1,213 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
daemonName = "netbird.exe"
uiName = "netbird-ui.exe"
updaterBinary = "updater.exe"
msiLogFile = "msi.log"
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
var (
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
// for the cleanup
binaryExtensions = []string{"msi", "exe"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(daemonFolder); err != nil {
log.Errorf("failed to start UI: %v", err)
}
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
}()
if dryRun {
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
installerType, err := typeByFileExtension(installerFile)
if err != nil {
log.Debugf("%v", err)
resultErr = err
return
}
var cmd *exec.Cmd
switch installerType {
case TypeExe:
log.Infof("run exe installer: %s", installerFile)
cmd = exec.CommandContext(ctx, installerFile, "/S")
default:
installerDir := filepath.Dir(installerFile)
logPath := filepath.Join(installerDir, msiLogFile)
log.Infof("run msi installer: %s", installerFile)
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
}
cmd.Dir = filepath.Dir(installerFile)
if resultErr = cmd.Start(); resultErr != nil {
log.Errorf("error starting installer: %v", resultErr)
return
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if resultErr = cmd.Wait(); resultErr != nil {
log.Errorf("installer process finished with error: %v", resultErr)
return
}
return nil
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser(daemonFolder string) error {
uiPath := filepath.Join(daemonFolder, uiName)
log.Infof("starting netbird-ui: %s", uiPath)
// Get the active console session ID
sessionID := windows.WTSGetActiveConsoleSessionId()
if sessionID == 0xFFFFFFFF {
return fmt.Errorf("no active user session found")
}
// Get the user token for that session
var userToken windows.Token
err := windows.WTSQueryUserToken(sessionID, &userToken)
if err != nil {
return fmt.Errorf("failed to query user token: %w", err)
}
defer func() {
if err := userToken.Close(); err != nil {
log.Warnf("failed to close user token: %v", err)
}
}()
// Duplicate the token to a primary token
var primaryToken windows.Token
err = windows.DuplicateTokenEx(
userToken,
windows.MAXIMUM_ALLOWED,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return fmt.Errorf("failed to duplicate token: %w", err)
}
defer func() {
if err := primaryToken.Close(); err != nil {
log.Warnf("failed to close token: %v", err)
}
}()
// Prepare startup info
var si windows.StartupInfo
si.Cb = uint32(unsafe.Sizeof(si))
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
var pi windows.ProcessInformation
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
if err != nil {
return fmt.Errorf("failed to convert path to UTF16: %w", err)
}
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
err = windows.CreateProcessAsUser(
primaryToken,
nil,
cmdLine,
nil,
nil,
false,
creationFlags,
nil,
nil,
&si,
&pi,
)
if err != nil {
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
}
// Close handles
if err := windows.CloseHandle(pi.Process); err != nil {
log.Warnf("failed to close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Warnf("failed to close thread handle: %v", err)
}
log.Infof("netbird-ui started successfully in session %d", sessionID)
return nil
}
func urlWithVersionArch(it Type, version string) string {
var url string
if it == TypeExe {
url = exeDownloadURL
} else {
url = msiDownloadURL
}
url = strings.ReplaceAll(url, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -0,0 +1,5 @@
package installer
const (
LogFile = "installer.log"
)

View File

@@ -0,0 +1,15 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run in a new session,
// making it independent of the parent daemon process. This ensures the updater
// survives when the daemon is stopped during the pkg installation.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
}

View File

@@ -0,0 +1,14 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run detached from the parent,
// making it independent of the parent daemon process.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
}
}

View File

@@ -0,0 +1,7 @@
//go:build devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
)

View File

@@ -0,0 +1,7 @@
//go:build !devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
)

View File

@@ -0,0 +1,230 @@
package installer
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
resultFile = "result.json"
)
type Result struct {
Success bool
Error string
ExecutedAt time.Time
}
// ResultHandler handles reading and writing update results
type ResultHandler struct {
resultFile string
}
// NewResultHandler creates a new communicator with the given directory path
// The result file will be created as "result.json" in the specified directory
func NewResultHandler(installerDir string) *ResultHandler {
// Create it if it doesn't exist
// do not care if already exists
_ = os.MkdirAll(installerDir, 0o700)
rh := &ResultHandler{
resultFile: filepath.Join(installerDir, resultFile),
}
return rh
}
func (rh *ResultHandler) GetErrorResultReason() string {
result, err := rh.tryReadResult()
if err == nil && !result.Success {
return result.Error
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return ""
}
func (rh *ResultHandler) WriteSuccess() error {
result := Result{
Success: true,
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) WriteErr(errReason error) error {
result := Result{
Success: false,
Error: errReason.Error(),
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
log.Infof("start watching result: %s", rh.resultFile)
// Check if file already exists (updater finished before we started watching)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
dir := filepath.Dir(rh.resultFile)
if err := rh.waitForDirectory(ctx, dir); err != nil {
return Result{}, err
}
return rh.watchForResultFile(ctx, dir)
}
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
ticker := time.NewTicker(300 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return nil
}
}
}
}
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Error(err)
return Result{}, err
}
defer func() {
if err := watcher.Close(); err != nil {
log.Warnf("failed to close watcher: %v", err)
}
}()
if err := watcher.Add(dir); err != nil {
return Result{}, fmt.Errorf("failed to watch directory: %v", err)
}
// Check again after setting up watcher to avoid race condition
// (file could have been created between initial check and watcher setup)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
for {
select {
case <-ctx.Done():
return Result{}, ctx.Err()
case event, ok := <-watcher.Events:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
if result, done := rh.handleWatchEvent(event); done {
return result, nil
}
case err, ok := <-watcher.Errors:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
return Result{}, fmt.Errorf("watcher error: %w", err)
}
}
}
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
if event.Name != rh.resultFile {
return Result{}, false
}
if event.Has(fsnotify.Create) {
result, err := rh.tryReadResult()
if err != nil {
log.Debugf("error while reading result: %v", err)
return result, true
}
log.Infof("installer result: %v", result)
return result, true
}
return Result{}, false
}
// Write writes the update result to a file for the UI to read
func (rh *ResultHandler) write(result Result) error {
log.Infof("write out installer result to: %s", rh.resultFile)
// Ensure directory exists
dir := filepath.Dir(rh.resultFile)
if err := os.MkdirAll(dir, 0o755); err != nil {
log.Errorf("failed to create directory %s: %v", dir, err)
return err
}
data, err := json.Marshal(result)
if err != nil {
return err
}
// Write to a temporary file first, then rename for atomic operation
tmpPath := rh.resultFile + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
log.Errorf("failed to create temp file: %s", err)
return err
}
// Atomic rename
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
log.Warnf("Failed to remove temp result file: %v", err)
}
return err
}
return nil
}
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil && !os.IsNotExist(err) {
return err
}
log.Debugf("delete installer result file: %s", rh.resultFile)
return nil
}
// tryReadResult attempts to read and validate the result file
func (rh *ResultHandler) tryReadResult() (Result, error) {
data, err := os.ReadFile(rh.resultFile)
if err != nil {
return Result{}, err
}
var result Result
if err := json.Unmarshal(data, &result); err != nil {
return Result{}, fmt.Errorf("invalid result format: %w", err)
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return result, nil
}

View File

@@ -0,0 +1,14 @@
package installer
type Type struct {
name string
downloadable bool
}
func (t Type) String() string {
return t.name
}
func (t Type) Downloadable() bool {
return t.downloadable
}

View File

@@ -0,0 +1,22 @@
package installer
import (
"context"
"os/exec"
)
var (
TypeHomebrew = Type{name: "Homebrew", downloadable: false}
TypePKG = Type{name: "pkg", downloadable: true}
)
func TypeOfInstaller(ctx context.Context) Type {
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
_, err := cmd.Output()
if err != nil && cmd.ProcessState.ExitCode() == 1 {
// Not installed using pkg file, thus installed using Homebrew
return TypeHomebrew
}
return TypePKG
}

View File

@@ -0,0 +1,51 @@
package installer
import (
"context"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
const (
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
)
var (
TypeExe = Type{name: "EXE", downloadable: true}
TypeMSI = Type{name: "MSI", downloadable: true}
)
func TypeOfInstaller(_ context.Context) Type {
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
for _, path := range paths {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
continue
}
if err := k.Close(); err != nil {
log.Warnf("Error closing registry key: %v", err)
}
return TypeExe
}
log.Debug("No registry entry found for Netbird, assuming MSI installation")
return TypeMSI
}
func typeByFileExtension(filePath string) (Type, error) {
switch {
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
return TypeExe, nil
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
return TypeMSI, nil
default:
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
}
}

View File

@@ -0,0 +1,374 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}
func (u UpdateState) Name() string {
return "autoUpdate"
}
type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
currentVersion string
update UpdateInterface
wg sync.WaitGroup
cancel context.CancelFunc
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protect update and expectedVersion fields
updateMutex sync.Mutex
triggerUpdateFn func(context.Context, string) error
}
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Home Brew installation")
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager, nil
}
// CheckUpdateSuccess checks if the update was successful and send a notification.
// It works without to start the update manager.
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
reason := m.lastResultErrReason()
if reason != "" {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %s", reason),
nil,
)
}
updateState, err := m.loadAndDeleteUpdateState(ctx)
if err != nil {
if errors.Is(err, errNoUpdateState) {
return
}
log.Errorf("failed to load update state: %v", err)
return
}
log.Debugf("auto-update state loaded, %v", *updateState)
if updateState.TargetVersion == m.currentVersion {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
nil,
)
return
}
}
func (m *Manager) Start(ctx context.Context) {
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
m.update.SetDaemonVersion(m.currentVersion)
m.update.SetOnUpdateListener(func() {
select {
case m.updateChannel <- struct{}{}:
default:
}
})
go m.update.StartFetcher()
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
go m.updateLoop(ctx)
}
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if expectedVersion == "" {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
return
}
if expectedVersion == latestVersion {
m.updateToLatestVersion = true
m.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("error parsing version: %v", err)
return
}
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
return
}
m.expectedVersion = expectedSemVer
m.updateToLatestVersion = false
}
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) Stop() {
if m.cancel == nil {
return
}
m.cancel()
m.updateMutex.Lock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
m.updateMutex.Unlock()
m.wg.Wait()
}
func (m *Manager) onContextCancel() {
if m.cancel == nil {
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
m.onContextCancel()
return
case <-m.mgmUpdateChan:
case <-m.updateChannel:
log.Infof("fetched new version info")
}
m.handleUpdate(ctx)
}
}
func (m *Manager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion) {
return
}
m.lastTrigger = time.Now()
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": updateVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: updateVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
if err = m.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
}
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
// Returns nil if no state exists.
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
stateType := &UpdateState{}
m.stateManager.RegisterState(stateType)
if err := m.stateManager.LoadState(stateType); err != nil {
return nil, fmt.Errorf("load state: %w", err)
}
state := m.stateManager.GetState(stateType)
if state == nil {
return nil, errNoUpdateState
}
updateState, ok := state.(*UpdateState)
if !ok {
return nil, fmt.Errorf("failed to cast state to UpdateState")
}
if err := m.stateManager.DeleteState(updateState); err != nil {
return nil, fmt.Errorf("delete state: %w", err)
}
if err := m.stateManager.PersistState(ctx); err != nil {
return nil, fmt.Errorf("persist state: %w", err)
}
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil {
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
return false
}
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
return true
}
func (m *Manager) lastResultErrReason() string {
inst := installer.New()
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}

View File

@@ -0,0 +1,214 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

View File

@@ -0,0 +1,39 @@
//go:build !windows && !darwin
package updatemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager is a no-op stub for unsupported platforms
type Manager struct{}
// NewManager returns a no-op manager for unsupported platforms
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
return nil, fmt.Errorf("update manager is not supported on this platform")
}
// CheckUpdateSuccess is a no-op on unsupported platforms
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
// no-op
}
// Start is a no-op on unsupported platforms
func (m *Manager) Start(ctx context.Context) {
// no-op
}
// SetVersion is a no-op on unsupported platforms
func (m *Manager) SetVersion(expectedVersion string) {
// no-op
}
// Stop is a no-op on unsupported platforms
func (m *Manager) Stop() {
// no-op
}

View File

@@ -0,0 +1,302 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/blake2s"
)
const (
tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
tagArtifactPublic = "ARTIFACT PUBLIC KEY"
maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
)
// ArtifactHash wraps a hash.Hash and counts bytes written
type ArtifactHash struct {
hash.Hash
}
// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
func NewArtifactHash() *ArtifactHash {
h, err := blake2s.New256(nil)
if err != nil {
panic(err) // Should never happen with nil Key
}
return &ArtifactHash{Hash: h}
}
func (ah *ArtifactHash) Write(b []byte) (int, error) {
return ah.Hash.Write(b)
}
// ArtifactKey is a signing Key used to sign artifacts
type ArtifactKey struct {
PrivateKey
}
func (k ArtifactKey) String() string {
return fmt.Sprintf(
"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
// Verify root key is still valid
if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
}
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
ak := &ArtifactKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(ak.PrivateKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
// Sign the public key with the root key
signature, err := SignArtifactKey(*rootKey, pubPEM)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
}
return ak, privPEM, pubPEM, signature, nil
}
func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
if err != nil {
return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
}
return ArtifactKey{pk}, nil
}
func ParseArtifactPubKey(data []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(data, tagArtifactPublic)
return pk, err
}
func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
if len(keys) == 0 {
return nil, nil, errors.New("no keys to bundle")
}
// Create bundle by concatenating PEM-encoded keys
var pubBundle []byte
for _, pk := range keys {
// Marshal PublicKey struct to JSON
pubJSON, err := json.Marshal(pk)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
pubBundle = append(pubBundle, pubPEM...)
}
// Sign the entire bundle with the root key
signature, err := SignArtifactKey(*rootKey, pubBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
}
return pubBundle, signature, nil
}
func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
// Reconstruct the signed message: artifact_key_data || timestamp
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("failed to verify signature of artifact keys")
}
pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
if err != nil {
log.Debugf("failed to parse public keys: %s", err)
return nil, err
}
validKeys := make([]PublicKey, 0, len(pubKeys))
for _, pubKey := range pubKeys {
// Filter out expired keys
if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
log.Debugf("Key %s is expired at %v (current time %v)",
pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
continue
}
if revocationList != nil {
if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
log.Debugf("Key %s is revoked as of %v (created %v)",
pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
continue
}
}
validKeys = append(validKeys, pubKey)
}
if len(validKeys) == 0 {
log.Debugf("no valid public keys found for artifact keys")
return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys))
}
return validKeys, nil
}
func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
// Validate signature timestamp
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("failed to verify signature of artifact: %s", err)
return err
}
if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
return fmt.Errorf("artifact signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return fmt.Errorf("failed to hash artifact: %w", err)
}
hash := h.Sum(nil)
// Reconstruct the signed message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
// Find matching Key and verify
for _, keyInfo := range artifactPubKeys {
if keyInfo.Metadata.ID == signature.KeyID {
// Check Key expiration
if !keyInfo.Metadata.ExpiresAt.IsZero() &&
signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
return fmt.Errorf("signing Key %s expired at %v, signature from %v",
signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
}
if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
return nil
}
return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
}
}
return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
}
func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
if len(data) == 0 { // Check happens too late
return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return nil, fmt.Errorf("failed to write artifact hash: %w", err)
}
timestamp := time.Now().UTC()
if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
}
hash := h.Sum(nil)
// Create message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(artifactKey.Key, msg)
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: artifactKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
return json.Marshal(bundle)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -0,0 +1,6 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -0,0 +1,174 @@
// Package reposign implements a cryptographic signing and verification system
// for NetBird software update artifacts. It provides a hierarchical key
// management system with support for key rotation, revocation, and secure
// artifact distribution.
//
// # Architecture
//
// The package uses a two-tier key hierarchy:
//
// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
// in the client binary and establish the root of trust. Root keys should
// be kept offline and highly secured.
//
// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
// packages, etc.). These are rotated regularly and can be revoked if
// compromised. Artifact keys are signed by root keys and distributed via
// a public repository.
//
// This separation allows for operational flexibility: artifact keys can be
// rotated frequently without requiring client updates, while root keys remain
// stable and embedded in the software.
//
// # Cryptographic Primitives
//
// The package uses strong, modern cryptographic algorithms:
// - Ed25519: Fast, secure digital signatures (no timing attacks)
// - BLAKE2s-256: Fast cryptographic hash for artifacts
// - SHA-256: Key ID generation
// - JSON: Structured key and signature serialization
// - PEM: Standard key encoding format
//
// # Security Features
//
// Timestamp Binding:
// - All signatures include cryptographically-bound timestamps
// - Prevents replay attacks and enforces signature freshness
// - Clock skew tolerance: 5 minutes
//
// Key Expiration:
// - All keys have expiration times
// - Expired keys are automatically rejected
// - Signing with an expired key fails immediately
//
// Key Revocation:
// - Compromised keys can be revoked via a signed revocation list
// - Revocation list is checked during artifact validation
// - Revoked keys are filtered out before artifact verification
//
// # File Structure
//
// The package expects the following file layout in the key repository:
//
// signrepo/
// artifact-key-pub.pem # Bundle of artifact public keys
// artifact-key-pub.pem.sig # Root signature of the bundle
// revocation-list.json # List of revoked key IDs
// revocation-list.json.sig # Root signature of revocation list
//
// And in the artifacts repository:
//
// releases/
// v0.28.0/
// netbird-linux-amd64
// netbird-linux-amd64.sig # Artifact signature
// netbird-darwin-amd64
// netbird-darwin-amd64.sig
// ...
//
// # Embedded Root Keys
//
// Root public keys are embedded in the client binary at compile time:
// - Production keys: certs/ directory
// - Development keys: certsdev/ directory
//
// The build tag determines which keys are embedded:
// - Production builds: //go:build !devartifactsign
// - Development builds: //go:build devartifactsign
//
// This ensures that development artifacts cannot be verified using production
// keys and vice versa.
//
// # Key Rotation Strategies
//
// Root Key Rotation:
//
// Root keys can be rotated without breaking existing clients by leveraging
// the multi-key verification system. The loadEmbeddedPublicKeys function
// reads ALL files from the certs/ directory and accepts signatures from ANY
// of the embedded root keys.
//
// To rotate root keys:
//
// 1. Generate a new root key pair:
// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
//
// 2. Add the new public key to the certs/ directory as a new file:
// certs/
// root-pub-2024.pem # Old key (keep this!)
// root-pub-2025.pem # New key (add this)
//
// 3. Build new client versions with both keys embedded. The verification
// will accept signatures from either key.
//
// 4. Start signing new artifact keys with the new root key. Old clients
// with only the old root key will reject these, but new clients with
// both keys will accept them.
//
// Each file in certs/ can contain a single key or a bundle of keys (multiple
// PEM blocks). The system will parse all keys from all files and use them
// for verification. This provides maximum flexibility for key management.
//
// Important: Never remove all old root keys at once. Always maintain at least
// one overlapping key between releases to ensure smooth transitions.
//
// Artifact Key Rotation:
//
// Artifact keys should be rotated regularly (e.g., every 90 days) using the
// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
// keys to be bundled together in a single signed package, and ValidateArtifact
// will accept signatures from ANY key in the bundle.
//
// To rotate artifact keys smoothly:
//
// 1. Generate a new artifact key while keeping the old one:
// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
// // Keep oldPubPEM and oldKey available
//
// 2. Create a bundle containing both old and new public keys
//
// 3. Upload the bundle and its signature to the key repository:
// signrepo/artifact-key-pub.pem # Contains both keys
// signrepo/artifact-key-pub.pem.sig # Root signature
//
// 4. Start signing new releases with the NEW key, but keep the bundle
// unchanged. Clients will download the bundle (containing both keys)
// and accept signatures from either key.
//
// Key bundle validation workflow:
// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
// 3. ValidateArtifactKeys parses all public keys from the bundle
// 4. ValidateArtifactKeys filters out expired or revoked keys
// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
//
// This multi-key acceptance model enables overlapping validity periods and
// smooth transitions without client update requirements.
//
// # Best Practices
//
// Root Key Management:
// - Generate root keys offline on an air-gapped machine
// - Store root private keys in hardware security modules (HSM) if possible
// - Use separate root keys for production and development
// - Rotate root keys infrequently (e.g., every 5-10 years)
// - Plan for root key rotation: embed multiple root public keys
//
// Artifact Key Management:
// - Rotate artifact keys regularly (e.g., every 90 days)
// - Use separate artifact keys for different release channels if needed
// - Revoke keys immediately upon suspected compromise
// - Bundle multiple artifact keys to enable smooth rotation
//
// Signing Process:
// - Sign artifacts in a secure CI/CD environment
// - Never commit private keys to version control
// - Use environment variables or secret management for keys
// - Verify signatures immediately after signing
//
// Distribution:
// - Serve keys and revocation lists from a reliable CDN
// - Use HTTPS for all key and artifact downloads
// - Monitor download failures and signature verification failures
// - Keep revocation list up to date
package reposign

View File

@@ -0,0 +1,10 @@
//go:build devartifactsign
package reposign
import "embed"
//go:embed certsdev
var embeddedCerts embed.FS
const embeddedCertsDir = "certsdev"

View File

@@ -0,0 +1,10 @@
//go:build !devartifactsign
package reposign
import "embed"
//go:embed certs
var embeddedCerts embed.FS
const embeddedCertsDir = "certs"

View File

@@ -0,0 +1,171 @@
package reposign
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"time"
)
const (
maxClockSkew = 5 * time.Minute
)
// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key)
type KeyID [8]byte
// computeKeyID generates a unique ID from a public Key
func computeKeyID(pub ed25519.PublicKey) KeyID {
h := sha256.Sum256(pub)
var id KeyID
copy(id[:], h[:8])
return id
}
// MarshalJSON implements json.Marshaler for KeyID
func (k KeyID) MarshalJSON() ([]byte, error) {
return json.Marshal(k.String())
}
// UnmarshalJSON implements json.Unmarshaler for KeyID
func (k *KeyID) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
parsed, err := ParseKeyID(s)
if err != nil {
return err
}
*k = parsed
return nil
}
// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
func ParseKeyID(s string) (KeyID, error) {
var id KeyID
if len(s) != 16 {
return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
}
b, err := hex.DecodeString(s)
if err != nil {
return id, fmt.Errorf("failed to decode KeyID: %w", err)
}
copy(id[:], b)
return id, nil
}
func (k KeyID) String() string {
return fmt.Sprintf("%x", k[:])
}
// KeyMetadata contains versioning and lifecycle information for a Key
type KeyMetadata struct {
ID KeyID `json:"id"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
}
// PublicKey wraps a public Key with its Metadata
type PublicKey struct {
Key ed25519.PublicKey
Metadata KeyMetadata
}
func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
var keys []PublicKey
for len(bundle) > 0 {
keyInfo, rest, err := parsePublicKey(bundle, typeTag)
if err != nil {
return nil, err
}
keys = append(keys, keyInfo)
bundle = rest
}
if len(keys) == 0 {
return nil, errors.New("no keys found in bundle")
}
return keys, nil
}
func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
b, rest := pem.Decode(data)
if b == nil {
return PublicKey{}, nil, errors.New("failed to decode PEM data")
}
if b.Type != typeTag {
return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pub PublicKey
if err := json.Unmarshal(b.Bytes, &pub); err != nil {
return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
}
// Validate key length
if len(pub.Key) != ed25519.PublicKeySize {
return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
ed25519.PublicKeySize, len(pub.Key))
}
// Always recompute ID to ensure integrity
pub.Metadata.ID = computeKeyID(pub.Key)
return pub, rest, nil
}
type PrivateKey struct {
Key ed25519.PrivateKey
Metadata KeyMetadata
}
func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
b, rest := pem.Decode(data)
if b == nil {
return PrivateKey{}, errors.New("failed to decode PEM data")
}
if len(rest) > 0 {
return PrivateKey{}, errors.New("trailing PEM data")
}
if b.Type != typeTag {
return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pk PrivateKey
if err := json.Unmarshal(b.Bytes, &pk); err != nil {
return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
}
// Validate key length
if len(pk.Key) != ed25519.PrivateKeySize {
return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
ed25519.PrivateKeySize, len(pk.Key))
}
return pk, nil
}
func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
// Verify with root keys
var rootKeys []ed25519.PublicKey
for _, r := range publicRootKeys {
rootKeys = append(rootKeys, r.Key)
}
for _, k := range rootKeys {
if ed25519.Verify(k, msg, sig) {
return true
}
}
return false
}

View File

@@ -0,0 +1,636 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test KeyID functions
func TestComputeKeyID(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
// Verify it's the first 8 bytes of SHA-256
h := sha256.Sum256(pub)
expectedID := KeyID{}
copy(expectedID[:], h[:8])
assert.Equal(t, expectedID, keyID)
}
func TestComputeKeyID_Deterministic(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Computing KeyID multiple times should give the same result
keyID1 := computeKeyID(pub)
keyID2 := computeKeyID(pub)
assert.Equal(t, keyID1, keyID2)
}
func TestComputeKeyID_DifferentKeys(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
// Different keys should produce different IDs
assert.NotEqual(t, keyID1, keyID2)
}
func TestParseKeyID_Valid(t *testing.T) {
hexStr := "0123456789abcdef"
keyID, err := ParseKeyID(hexStr)
require.NoError(t, err)
expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
assert.Equal(t, expected, keyID)
}
func TestParseKeyID_InvalidLength(t *testing.T) {
tests := []struct {
name string
input string
}{
{"too short", "01234567"},
{"too long", "0123456789abcdef00"},
{"empty", ""},
{"odd length", "0123456789abcde"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseKeyID(tt.input)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid KeyID length")
})
}
}
func TestParseKeyID_InvalidHex(t *testing.T) {
invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
_, err := ParseKeyID(invalidHex)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode KeyID")
}
func TestKeyID_String(t *testing.T) {
keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
str := keyID.String()
assert.Equal(t, "0123456789abcdef", str)
}
func TestKeyID_RoundTrip(t *testing.T) {
original := "fedcba9876543210"
keyID, err := ParseKeyID(original)
require.NoError(t, err)
result := keyID.String()
assert.Equal(t, original, result)
}
func TestKeyID_ZeroValue(t *testing.T) {
keyID := KeyID{}
str := keyID.String()
assert.Equal(t, "0000000000000000", str)
}
// Test KeyMetadata
func TestKeyMetadata_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, metadata.ID, decoded.ID)
assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
}
func TestKeyMetadata_NoExpiration(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Time{}, // Zero value = no expiration
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.True(t, decoded.ExpiresAt.IsZero())
}
// Test PublicKey
func TestPublicKey_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
var decoded PublicKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, pubKey.Key, decoded.Key)
assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePublicKey
func TestParsePublicKey_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
}
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
// Marshal to JSON
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode to PEM
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse it back
parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
assert.Empty(t, rest)
assert.Equal(t, pub, parsed.Key)
assert.Equal(t, metadata.ID, parsed.Metadata.ID)
}
func TestParsePublicKey_InvalidPEM(t *testing.T) {
invalidPEM := []byte("not a PEM")
_, _, err := parsePublicKey(invalidPEM, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePublicKey_WrongType(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode with wrong type
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePublicKey_InvalidJSON(t *testing.T) {
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: []byte("invalid json"),
})
_, _, err := parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParsePublicKey_InvalidKeySize(t *testing.T) {
// Create a public key with wrong size
pubKey := PublicKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
}
func TestParsePublicKey_IDRecomputation(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Create a public key with WRONG ID
wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: wrongID,
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse should recompute the correct ID
parsed, _, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
correctID := computeKeyID(pub)
assert.Equal(t, correctID, parsed.Metadata.ID)
assert.NotEqual(t, wrongID, parsed.Metadata.ID)
}
// Test parsePublicKeyBundle
func TestParsePublicKeyBundle_Single(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Equal(t, pub, keys[0].Key)
}
func TestParsePublicKeyBundle_Multiple(t *testing.T) {
var bundle []byte
// Create 3 keys
for i := 0; i < 3; i++ {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
bundle = append(bundle, pemData...)
}
keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 3)
}
func TestParsePublicKeyBundle_Empty(t *testing.T) {
_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no keys found")
}
func TestParsePublicKeyBundle_Invalid(t *testing.T) {
_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
assert.Error(t, err)
}
// Test PrivateKey
func TestPrivateKey_JSONMarshaling(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
var decoded PrivateKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, privKey.Key, decoded.Key)
assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePrivateKey
func TestParsePrivateKey_Valid(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
parsed, err := parsePrivateKey(pemData, tagRootPrivate)
require.NoError(t, err)
assert.Equal(t, priv, parsed.Key)
}
func TestParsePrivateKey_InvalidPEM(t *testing.T) {
_, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePrivateKey_TrailingData(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
// Add trailing data
pemData = append(pemData, []byte("extra data")...)
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "trailing PEM data")
}
func TestParsePrivateKey_WrongType(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
privKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
// Test verifyAny
func TestVerifyAny_ValidSignature(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_InvalidSignature(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
invalidSignature := make([]byte, ed25519.SignatureSize)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, invalidSignature)
assert.False(t, result)
}
func TestVerifyAny_MultipleKeys(t *testing.T) {
// Create 3 key pairs
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub3, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
{Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
{Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_NoMatchingKey(t *testing.T) {
_, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
// Only include pub2, not pub1
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
}
result := verifyAny(rootKeys, message, signature)
assert.False(t, result)
}
func TestVerifyAny_EmptyKeys(t *testing.T) {
message := []byte("test message")
signature := make([]byte, ed25519.SignatureSize)
result := verifyAny([]PublicKey{}, message, signature)
assert.False(t, result)
}
func TestVerifyAny_TamperedMessage(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
}
// Verify with different message
tamperedMessage := []byte("different message")
result := verifyAny(rootKeys, tamperedMessage, signature)
assert.False(t, result)
}

Some files were not shown because too many files have changed in this diff Show More