mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
51 Commits
v0.60.1
...
test-freeb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea997f4a26 | ||
|
|
73201c4f3e | ||
|
|
33d1761fe8 | ||
|
|
aa914a0f26 | ||
|
|
ab6a9e85de | ||
|
|
d3b123c76d | ||
|
|
fc4932a23f | ||
|
|
b7e98acd1f | ||
|
|
433bc4ead9 | ||
|
|
011cc81678 | ||
|
|
537151e0f3 | ||
|
|
a9c28ef723 | ||
|
|
c29bb1a289 | ||
|
|
447cd287f5 | ||
|
|
5748bdd64e | ||
|
|
08f31fbcb3 | ||
|
|
932c02eaab | ||
|
|
abcbde26f9 | ||
|
|
90e3b8009f | ||
|
|
94d34dc0c5 | ||
|
|
44851e06fb | ||
|
|
3f4f825ec1 | ||
|
|
f538e6e9ae | ||
|
|
cb6b086164 | ||
|
|
71b6855e09 | ||
|
|
9bdc4908fb | ||
|
|
031ab11178 | ||
|
|
d2e48d4f5e | ||
|
|
27dd97c9c4 | ||
|
|
e87b4ace11 | ||
|
|
a232cf614c | ||
|
|
a293f760af | ||
|
|
10e9cf8c62 | ||
|
|
7193bd2da7 | ||
|
|
52948ccd61 | ||
|
|
4b77359042 | ||
|
|
387d43bcc1 | ||
|
|
e47d815dd2 | ||
|
|
cb83b7c0d3 | ||
|
|
ddcd182859 | ||
|
|
aca0398105 | ||
|
|
02200d790b | ||
|
|
f31bba87b4 | ||
|
|
7285fef0f0 | ||
|
|
20973063d8 | ||
|
|
ba2e9b6d88 | ||
|
|
131d7a3694 | ||
|
|
290fe2d8b9 | ||
|
|
7fb1a2fe31 | ||
|
|
32146e576d | ||
|
|
1311364397 |
11
.githooks/pre-push
Executable file
11
.githooks/pre-push
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Running pre-push hook..."
|
||||
if ! make lint; then
|
||||
echo ""
|
||||
echo "Hint: To push without verification, run:"
|
||||
echo " git push --no-verify"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "All checks passed!"
|
||||
7
.github/workflows/golang-test-darwin.yml
vendored
7
.github/workflows/golang-test-darwin.yml
vendored
@@ -15,13 +15,14 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
|
||||
2
.github/workflows/golang-test-freebsd.yml
vendored
2
.github/workflows/golang-test-freebsd.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
release: "14.2"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
71
.github/workflows/golang-test-linux.yml
vendored
71
.github/workflows/golang-test-linux.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
@@ -106,15 +106,15 @@ jobs:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -151,15 +151,15 @@ jobs:
|
||||
needs: [ build-cache ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
id: go-env
|
||||
run: |
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||
-e CONTAINER=${CONTAINER} \
|
||||
golang:1.23-alpine \
|
||||
golang:1.24-alpine \
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
@@ -220,15 +220,15 @@ jobs:
|
||||
raceFlag: "-race"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
@@ -270,15 +270,15 @@ jobs:
|
||||
arch: [ '386','amd64' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache.outputs.cache-hit != 'true'
|
||||
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||
@@ -321,15 +321,15 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -408,15 +408,16 @@ jobs:
|
||||
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -497,15 +498,15 @@ jobs:
|
||||
-p 9090:9090 \
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
@@ -561,15 +562,15 @@ jobs:
|
||||
store: [ 'sqlite', 'postgres']
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
|
||||
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/setup-go@v5
|
||||
id: go
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Get Go environment
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -46,7 +46,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Install dependencies
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
with:
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
- name: Setup NDK
|
||||
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build android netbird lib
|
||||
@@ -56,9 +56,9 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
|
||||
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20251113184115-a159579294ab
|
||||
- name: gomobile init
|
||||
run: gomobile init
|
||||
- name: build iOS netbird lib
|
||||
|
||||
104
.github/workflows/release.yml
vendored
104
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.23"
|
||||
SIGN_PIPE_VER: "v0.1.0"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -19,8 +19,102 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
release_freebsd_port:
|
||||
name: "FreeBSD Port / Build & Test"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
|
||||
- name: Generate FreeBSD port issue body
|
||||
run: bash release_files/freebsd-port-issue-body.sh
|
||||
|
||||
- name: Check if diff was generated
|
||||
id: check_diff
|
||||
run: |
|
||||
if ls netbird-*.diff 1> /dev/null 2>&1; then
|
||||
echo "diff_exists=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "diff_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No diff file generated (port may already be up to date)"
|
||||
fi
|
||||
|
||||
- name: Extract version
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
id: version
|
||||
run: |
|
||||
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
|
||||
# Clone ports tree (shallow, only what we need)
|
||||
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
|
||||
cd /usr/ports
|
||||
|
||||
run: |
|
||||
set -e -x
|
||||
export PATH=$PATH:/usr/local/go/bin
|
||||
|
||||
# Find the diff file
|
||||
echo "Finding diff file..."
|
||||
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
|
||||
echo "Found: $DIFF_FILE"
|
||||
|
||||
if [[ -z "$DIFF_FILE" ]]; then
|
||||
echo "ERROR: Could not find diff file"
|
||||
find ~ -name "*.diff" -type f 2>/dev/null || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
|
||||
cd /usr/ports
|
||||
patch -p1 -V none < "$DIFF_FILE"
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
./netbird-*-issue.txt
|
||||
./netbird-*.diff
|
||||
retention-days: 30
|
||||
|
||||
release:
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
@@ -40,7 +134,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -136,7 +230,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -200,7 +294,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23"
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
|
||||
@@ -67,10 +67,13 @@ jobs:
|
||||
- name: Install curl
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -80,9 +83,6 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup MySQL privileges
|
||||
if: matrix.store == 'mysql'
|
||||
run: |
|
||||
|
||||
8
.github/workflows/wasm-build-validation.yml
vendored
8
.github/workflows/wasm-build-validation.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
env:
|
||||
@@ -60,8 +60,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 52428800 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||
if [ ${SIZE} -gt 57671680 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -136,6 +136,14 @@ checked out and set up:
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
6. Configure Git hooks for automatic linting:
|
||||
|
||||
```bash
|
||||
make setup-hooks
|
||||
```
|
||||
|
||||
This will configure Git to run linting automatically before each push, helping catch issues early.
|
||||
|
||||
### Dev Container Support
|
||||
|
||||
If you prefer using a dev container for development, NetBird now includes support for dev containers.
|
||||
|
||||
27
Makefile
Normal file
27
Makefile
Normal file
@@ -0,0 +1,27 @@
|
||||
.PHONY: lint lint-all lint-install setup-hooks
|
||||
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
|
||||
# Install golangci-lint locally if needed
|
||||
$(GOLANGCI_LINT):
|
||||
@echo "Installing golangci-lint..."
|
||||
@mkdir -p ./bin
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# Lint only changed files (fast, for pre-push)
|
||||
lint: $(GOLANGCI_LINT)
|
||||
@echo "Running lint on changed files..."
|
||||
@$(GOLANGCI_LINT) run --new-from-rev=origin/main --timeout=2m
|
||||
|
||||
# Lint entire codebase (slow, matches CI)
|
||||
lint-all: $(GOLANGCI_LINT)
|
||||
@echo "Running lint on all files..."
|
||||
@$(GOLANGCI_LINT) run --timeout=12m
|
||||
|
||||
# Just install the linter
|
||||
lint-install: $(GOLANGCI_LINT)
|
||||
|
||||
# Setup git hooks for all developers
|
||||
setup-hooks:
|
||||
@git config core.hooksPath .githooks
|
||||
@chmod +x .githooks/pre-push
|
||||
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
@@ -113,7 +112,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
@@ -4,10 +4,13 @@ package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -16,10 +19,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -53,7 +59,6 @@ func init() {
|
||||
|
||||
// Client struct manage the life circle of background service
|
||||
type Client struct {
|
||||
cfgFile string
|
||||
tunAdapter device.TunAdapter
|
||||
iFaceDiscover IFaceDiscover
|
||||
recorder *peer.Status
|
||||
@@ -67,12 +72,11 @@ type Client struct {
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||
execWorkaround(androidSDKVersion)
|
||||
|
||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||
return &Client{
|
||||
cfgFile: cfgFile,
|
||||
deviceName: deviceName,
|
||||
uiVersion: uiVersion,
|
||||
tunAdapter: tunAdapter,
|
||||
@@ -84,10 +88,16 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
|
||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
ConfigPath: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -107,23 +117,29 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
||||
c.ctxCancelLock.Unlock()
|
||||
|
||||
auth := NewAuthWithConfig(ctx, cfg)
|
||||
err = auth.login(urlOpener)
|
||||
err = auth.login(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
// In this case make no sense handle registration steps.
|
||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||
exportEnvList(envList)
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
|
||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
ConfigPath: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -141,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -156,6 +172,19 @@ func (c *Client) Stop() {
|
||||
c.ctxCancel()
|
||||
}
|
||||
|
||||
func (c *Client) RenewTun(fd int) error {
|
||||
if c.connectClient == nil {
|
||||
return fmt.Errorf("engine not running")
|
||||
}
|
||||
|
||||
e := c.connectClient.Engine()
|
||||
if e == nil {
|
||||
return fmt.Errorf("engine not initialized")
|
||||
}
|
||||
|
||||
return e.RenewTun(fd)
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
@@ -177,6 +206,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
p.IP,
|
||||
p.FQDN,
|
||||
p.ConnStatus.String(),
|
||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||
}
|
||||
peerInfos[n] = pi
|
||||
}
|
||||
@@ -201,31 +231,43 @@ func (c *Client) Networks() *NetworkArray {
|
||||
return nil
|
||||
}
|
||||
|
||||
routeSelector := routeManager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
log.Error("could not get route selector")
|
||||
return nil
|
||||
}
|
||||
|
||||
networkArray := &NetworkArray{
|
||||
items: make([]Network, 0),
|
||||
}
|
||||
|
||||
resolvedDomains := c.recorder.GetResolvedDomainsStates()
|
||||
|
||||
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||
if len(routes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
domains := c.getNetworkDomainsFromRoute(r, resolvedDomains)
|
||||
netStr := r.Network.String()
|
||||
|
||||
if r.IsDynamic() {
|
||||
netStr = r.Domains.SafeString()
|
||||
}
|
||||
|
||||
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
routePeer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||
if err != nil {
|
||||
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||
continue
|
||||
}
|
||||
network := Network{
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: peer.FQDN,
|
||||
Status: peer.ConnStatus.String(),
|
||||
Name: string(id),
|
||||
Network: netStr,
|
||||
Peer: routePeer.FQDN,
|
||||
Status: routePeer.ConnStatus.String(),
|
||||
IsSelected: routeSelector.IsSelected(id),
|
||||
Domains: domains,
|
||||
}
|
||||
networkArray.Add(network)
|
||||
}
|
||||
@@ -253,6 +295,69 @@ func (c *Client) RemoveConnectionListener() {
|
||||
c.recorder.RemoveConnectionListener()
|
||||
}
|
||||
|
||||
func (c *Client) toggleRoute(command routeCommand) error {
|
||||
return command.toggleRoute()
|
||||
}
|
||||
|
||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||
client := c.connectClient
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
engine := client.Engine()
|
||||
if engine == nil {
|
||||
return nil, fmt.Errorf("engine is not running")
|
||||
}
|
||||
|
||||
manager := engine.GetRouteManager()
|
||||
if manager == nil {
|
||||
return nil, fmt.Errorf("could not get route manager")
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func (c *Client) SelectRoute(route string) error {
|
||||
manager, err := c.getRouteManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.toggleRoute(selectRouteCommand{route: route, manager: manager})
|
||||
}
|
||||
|
||||
func (c *Client) DeselectRoute(route string) error {
|
||||
manager, err := c.getRouteManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.toggleRoute(deselectRouteCommand{route: route, manager: manager})
|
||||
}
|
||||
|
||||
// getNetworkDomainsFromRoute extracts domains from a route and enriches each domain
|
||||
// with its resolved IP addresses from the provided resolvedDomains map.
|
||||
func (c *Client) getNetworkDomainsFromRoute(route *route.Route, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) NetworkDomains {
|
||||
domains := NetworkDomains{}
|
||||
|
||||
for _, d := range route.Domains {
|
||||
networkDomain := NetworkDomain{
|
||||
Address: d.SafeString(),
|
||||
}
|
||||
|
||||
if info, exists := resolvedDomains[d]; exists {
|
||||
for _, prefix := range info.Prefixes {
|
||||
networkDomain.addResolvedIP(prefix.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
domains.Add(&networkDomain)
|
||||
}
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
func exportEnvList(list *EnvList) {
|
||||
if list == nil {
|
||||
return
|
||||
|
||||
@@ -32,7 +32,7 @@ type ErrListener interface {
|
||||
// URLOpener it is a callback interface. The Open function will be triggered if
|
||||
// the backend want to show an url for the user
|
||||
type URLOpener interface {
|
||||
Open(string)
|
||||
Open(url string, userCode string)
|
||||
OnLoginSuccess()
|
||||
}
|
||||
|
||||
@@ -148,9 +148,9 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
||||
}
|
||||
|
||||
// Login try register the client on the server
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidTV bool) {
|
||||
go func() {
|
||||
err := a.login(urlOpener)
|
||||
err := a.login(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
resultListener.OnError(err)
|
||||
} else {
|
||||
@@ -159,7 +159,7 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Auth) login(urlOpener URLOpener) error {
|
||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||
var needsLogin bool
|
||||
|
||||
// check if we need to generate JWT token
|
||||
@@ -173,7 +173,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
|
||||
jwtToken := ""
|
||||
if needsLogin {
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener)
|
||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -199,8 +199,8 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,7 +210,7 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, err
|
||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete)
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
||||
|
||||
56
client/android/network_domains.go
Normal file
56
client/android/network_domains.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
type ResolvedIPs struct {
|
||||
resolvedIPs []string
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Add(ipAddress string) {
|
||||
r.resolvedIPs = append(r.resolvedIPs, ipAddress)
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Get(i int) (string, error) {
|
||||
if i < 0 || i >= len(r.resolvedIPs) {
|
||||
return "", fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return r.resolvedIPs[i], nil
|
||||
}
|
||||
|
||||
func (r *ResolvedIPs) Size() int {
|
||||
return len(r.resolvedIPs)
|
||||
}
|
||||
|
||||
type NetworkDomain struct {
|
||||
Address string
|
||||
resolvedIPs ResolvedIPs
|
||||
}
|
||||
|
||||
func (d *NetworkDomain) addResolvedIP(resolvedIP string) {
|
||||
d.resolvedIPs.Add(resolvedIP)
|
||||
}
|
||||
|
||||
func (d *NetworkDomain) GetResolvedIPs() *ResolvedIPs {
|
||||
return &d.resolvedIPs
|
||||
}
|
||||
|
||||
type NetworkDomains struct {
|
||||
domains []*NetworkDomain
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Add(domain *NetworkDomain) {
|
||||
n.domains = append(n.domains, domain)
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Get(i int) (*NetworkDomain, error) {
|
||||
if i < 0 || i >= len(n.domains) {
|
||||
return nil, fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return n.domains[i], nil
|
||||
}
|
||||
|
||||
func (n *NetworkDomains) Size() int {
|
||||
return len(n.domains)
|
||||
}
|
||||
@@ -3,10 +3,16 @@
|
||||
package android
|
||||
|
||||
type Network struct {
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
Name string
|
||||
Network string
|
||||
Peer string
|
||||
Status string
|
||||
IsSelected bool
|
||||
Domains NetworkDomains
|
||||
}
|
||||
|
||||
func (n Network) GetNetworkDomains() *NetworkDomains {
|
||||
return &n.Domains
|
||||
}
|
||||
|
||||
type NetworkArray struct {
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||
@@ -5,6 +7,11 @@ type PeerInfo struct {
|
||||
IP string
|
||||
FQDN string
|
||||
ConnStatus string // Todo replace to enum
|
||||
Routes PeerRoutes
|
||||
}
|
||||
|
||||
func (p *PeerInfo) GetPeerRoutes() *PeerRoutes {
|
||||
return &p.Routes
|
||||
}
|
||||
|
||||
// PeerInfoArray is a wrapper of []PeerInfo
|
||||
|
||||
20
client/android/peer_routes.go
Normal file
20
client/android/peer_routes.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
type PeerRoutes struct {
|
||||
routes []string
|
||||
}
|
||||
|
||||
func (p *PeerRoutes) Get(i int) (string, error) {
|
||||
if i < 0 || i >= len(p.routes) {
|
||||
return "", fmt.Errorf("%d is out of range", i)
|
||||
}
|
||||
return p.routes[i], nil
|
||||
}
|
||||
|
||||
func (p *PeerRoutes) Size() int {
|
||||
return len(p.routes)
|
||||
}
|
||||
10
client/android/platform_files.go
Normal file
10
client/android/platform_files.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
// PlatformFiles groups paths to files used internally by the engine that can't be created/modified
|
||||
// at their default locations due to android OS restrictions.
|
||||
type PlatformFiles interface {
|
||||
ConfigurationFilePath() string
|
||||
StateFilePath() string
|
||||
}
|
||||
257
client/android/profile_manager.go
Normal file
257
client/android/profile_manager.go
Normal file
@@ -0,0 +1,257 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
// Android-specific config filename (different from desktop default.json)
|
||||
defaultConfigFilename = "netbird.cfg"
|
||||
// Subdirectory for non-default profiles (must match Java Preferences.java)
|
||||
profilesSubdir = "profiles"
|
||||
// Android uses a single user context per app (non-empty username required by ServiceManager)
|
||||
androidUsername = "android"
|
||||
)
|
||||
|
||||
// Profile represents a profile for gomobile
|
||||
type Profile struct {
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
// ProfileArray wraps profiles for gomobile compatibility
|
||||
type ProfileArray struct {
|
||||
items []*Profile
|
||||
}
|
||||
|
||||
// Length returns the number of profiles
|
||||
func (p *ProfileArray) Length() int {
|
||||
return len(p.items)
|
||||
}
|
||||
|
||||
// Get returns the profile at index i
|
||||
func (p *ProfileArray) Get(i int) *Profile {
|
||||
if i < 0 || i >= len(p.items) {
|
||||
return nil
|
||||
}
|
||||
return p.items[i]
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
/data/data/io.netbird.client/files/ ← configDir parameter
|
||||
├── netbird.cfg ← Default profile config
|
||||
├── state.json ← Default profile state
|
||||
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
|
||||
└── profiles/ ← Subdirectory for non-default profiles
|
||||
├── work.json ← Work profile config
|
||||
├── work.state.json ← Work profile state
|
||||
├── personal.json ← Personal profile config
|
||||
└── personal.state.json ← Personal profile state
|
||||
*/
|
||||
|
||||
// ProfileManager manages profiles for Android
|
||||
// It wraps the internal profilemanager to provide Android-specific behavior
|
||||
type ProfileManager struct {
|
||||
configDir string
|
||||
serviceMgr *profilemanager.ServiceManager
|
||||
}
|
||||
|
||||
// NewProfileManager creates a new profile manager for Android
|
||||
func NewProfileManager(configDir string) *ProfileManager {
|
||||
// Set the default config path for Android (stored in root configDir, not profiles/)
|
||||
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
|
||||
|
||||
// Set global paths for Android
|
||||
profilemanager.DefaultConfigPathDir = configDir
|
||||
profilemanager.DefaultConfigPath = defaultConfigPath
|
||||
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
|
||||
|
||||
// Create ServiceManager with profiles/ subdirectory
|
||||
// This avoids modifying the global ConfigDirOverride for profile listing
|
||||
profilesDir := filepath.Join(configDir, profilesSubdir)
|
||||
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
|
||||
|
||||
return &ProfileManager{
|
||||
configDir: configDir,
|
||||
serviceMgr: serviceMgr,
|
||||
}
|
||||
}
|
||||
|
||||
// ListProfiles returns all available profiles
|
||||
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
|
||||
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list profiles: %w", err)
|
||||
}
|
||||
|
||||
// Convert internal profiles to Android Profile type
|
||||
var profiles []*Profile
|
||||
for _, p := range internalProfiles {
|
||||
profiles = append(profiles, &Profile{
|
||||
Name: p.Name,
|
||||
IsActive: p.IsActive,
|
||||
})
|
||||
}
|
||||
|
||||
return &ProfileArray{items: profiles}, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the currently active profile name
|
||||
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return activeState.Name, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches to a different profile
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
Username: androidUsername,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("switched to profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddProfile creates a new profile
|
||||
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||
// Use ServiceManager (creates profile in profiles/ directory)
|
||||
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to add profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("created new profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutProfile logs out from a profile (clears authentication)
|
||||
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
configPath, err := pm.getProfileConfigPath(profileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if profile exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||
}
|
||||
|
||||
// Read current config using internal profilemanager
|
||||
config, err := profilemanager.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read profile config: %w", err)
|
||||
}
|
||||
|
||||
// Clear authentication by removing private key and SSH key
|
||||
config.PrivateKey = ""
|
||||
config.SSHKey = ""
|
||||
|
||||
// Save config using internal profilemanager
|
||||
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("logged out from profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfile deletes a profile
|
||||
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||
// Use ServiceManager (removes profile from profiles/ directory)
|
||||
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("removed profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProfileConfigPath returns the config file path for a profile
|
||||
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
// Android uses netbird.cfg for default profile instead of default.json
|
||||
// Default profile is stored in root configDir, not in profiles/
|
||||
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||
}
|
||||
|
||||
// Non-default profiles are stored in profiles subdirectory
|
||||
// This matches the Java Preferences.java expectation
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||
}
|
||||
|
||||
// GetConfigPath returns the config file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||
return pm.getProfileConfigPath(profileName)
|
||||
}
|
||||
|
||||
// GetStateFilePath returns the state file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
return filepath.Join(pm.configDir, "state.json"), nil
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||
}
|
||||
|
||||
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
|
||||
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||
activeProfile, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetConfigPath(activeProfile)
|
||||
}
|
||||
|
||||
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||
activeProfile, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetStateFilePath(activeProfile)
|
||||
}
|
||||
|
||||
// sanitizeProfileName removes invalid characters from profile name
|
||||
func sanitizeProfileName(name string) string {
|
||||
// Keep only alphanumeric, underscore, and hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
67
client/android/route_command.go
Normal file
67
client/android/route_command.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func executeRouteToggle(id string, manager routemanager.Manager,
|
||||
operationName string,
|
||||
routeOperation func(routes []route.NetID, allRoutes []route.NetID) error) error {
|
||||
netID := route.NetID(id)
|
||||
routes := []route.NetID{netID}
|
||||
|
||||
log.Debugf("%s with id: %s", operationName, id)
|
||||
|
||||
if err := routeOperation(routes, maps.Keys(manager.GetClientRoutesWithNetID())); err != nil {
|
||||
log.Debugf("error when %s: %s", operationName, err)
|
||||
return fmt.Errorf("error %s: %w", operationName, err)
|
||||
}
|
||||
|
||||
manager.TriggerSelection(manager.GetClientRoutes())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type routeCommand interface {
|
||||
toggleRoute() error
|
||||
}
|
||||
|
||||
type selectRouteCommand struct {
|
||||
route string
|
||||
manager routemanager.Manager
|
||||
}
|
||||
|
||||
func (s selectRouteCommand) toggleRoute() error {
|
||||
routeSelector := s.manager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return fmt.Errorf("no route selector available")
|
||||
}
|
||||
|
||||
routeOperation := func(routes []route.NetID, allRoutes []route.NetID) error {
|
||||
return routeSelector.SelectRoutes(routes, true, allRoutes)
|
||||
}
|
||||
|
||||
return executeRouteToggle(s.route, s.manager, "selecting route", routeOperation)
|
||||
}
|
||||
|
||||
type deselectRouteCommand struct {
|
||||
route string
|
||||
manager routemanager.Manager
|
||||
}
|
||||
|
||||
func (d deselectRouteCommand) toggleRoute() error {
|
||||
routeSelector := d.manager.GetRouteSelector()
|
||||
if routeSelector == nil {
|
||||
return fmt.Errorf("no route selector available")
|
||||
}
|
||||
|
||||
return executeRouteToggle(d.route, d.manager, "deselecting route", routeSelector.DeselectRoutes)
|
||||
}
|
||||
@@ -4,14 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
@@ -332,7 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), false, hint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := openBrowser(verificationURIComplete); err != nil {
|
||||
if err := util.OpenBrowser(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||
func openBrowser(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
return open.Run(url)
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -85,6 +85,9 @@ var (
|
||||
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
}
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
|
||||
176
client/cmd/signer/artifactkey.go
Normal file
176
client/cmd/signer/artifactkey.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
bundlePubKeysRootPrivKeyFile string
|
||||
bundlePubKeysPubKeyFiles []string
|
||||
bundlePubKeysFile string
|
||||
|
||||
createArtifactKeyRootPrivKeyFile string
|
||||
createArtifactKeyPrivKeyFile string
|
||||
createArtifactKeyPubKeyFile string
|
||||
createArtifactKeyExpiration time.Duration
|
||||
)
|
||||
|
||||
var createArtifactKeyCmd = &cobra.Command{
|
||||
Use: "create-artifact-key",
|
||||
Short: "Create a new artifact signing key",
|
||||
Long: `Generate a new artifact signing key pair signed by the root private key.
|
||||
The artifact key will be used to sign software artifacts/updates.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if createArtifactKeyExpiration <= 0 {
|
||||
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||
}
|
||||
|
||||
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
|
||||
return fmt.Errorf("failed to create artifact key: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var bundlePubKeysCmd = &cobra.Command{
|
||||
Use: "bundle-pub-keys",
|
||||
Short: "Bundle multiple artifact public keys into a signed package",
|
||||
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
|
||||
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(bundlePubKeysPubKeyFiles) == 0 {
|
||||
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
|
||||
}
|
||||
|
||||
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
|
||||
return fmt.Errorf("failed to bundle public keys: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(createArtifactKeyCmd)
|
||||
|
||||
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
|
||||
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
|
||||
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
|
||||
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
|
||||
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||
}
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
|
||||
}
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||
}
|
||||
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||
panic(fmt.Errorf("mark expiration as required: %w", err))
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(bundlePubKeysCmd)
|
||||
|
||||
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
|
||||
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
|
||||
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
|
||||
|
||||
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||
}
|
||||
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||
}
|
||||
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
|
||||
cmd.Println("Creating new artifact signing key...")
|
||||
|
||||
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read root private key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate artifact key: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
|
||||
}
|
||||
|
||||
signatureFile := artifactPubKeyFile + ".sig"
|
||||
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("✅ Artifact key created successfully.\n")
|
||||
cmd.Printf("%s\n", artifactKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
|
||||
cmd.Println("📦 Bundling public keys into signed package...")
|
||||
|
||||
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read root private key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
|
||||
for _, pubFile := range artifactPubKeyFiles {
|
||||
pubPem, err := os.ReadFile(pubFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read public key file: %w", err)
|
||||
}
|
||||
|
||||
pk, err := reposign.ParseArtifactPubKey(pubPem)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse artifact key: %w", err)
|
||||
}
|
||||
publicKeys = append(publicKeys, pk)
|
||||
}
|
||||
|
||||
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bundle artifact keys: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
|
||||
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
|
||||
}
|
||||
|
||||
signatureFile := bundlePubKeysFile + ".sig"
|
||||
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
|
||||
return nil
|
||||
}
|
||||
276
client/cmd/signer/artifactsign.go
Normal file
276
client/cmd/signer/artifactsign.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
|
||||
)
|
||||
|
||||
var (
|
||||
signArtifactPrivKeyFile string
|
||||
signArtifactArtifactFile string
|
||||
|
||||
verifyArtifactPubKeyFile string
|
||||
verifyArtifactFile string
|
||||
verifyArtifactSignatureFile string
|
||||
|
||||
verifyArtifactKeyPubKeyFile string
|
||||
verifyArtifactKeyRootPubKeyFile string
|
||||
verifyArtifactKeySignatureFile string
|
||||
verifyArtifactKeyRevocationFile string
|
||||
)
|
||||
|
||||
var signArtifactCmd = &cobra.Command{
|
||||
Use: "sign-artifact",
|
||||
Short: "Sign an artifact using an artifact private key",
|
||||
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
|
||||
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
|
||||
return fmt.Errorf("failed to sign artifact: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var verifyArtifactCmd = &cobra.Command{
|
||||
Use: "verify-artifact",
|
||||
Short: "Verify an artifact signature using an artifact public key",
|
||||
Long: `Verify a software artifact signature using the artifact's public key.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
|
||||
return fmt.Errorf("failed to verify artifact: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var verifyArtifactKeyCmd = &cobra.Command{
|
||||
Use: "verify-artifact-key",
|
||||
Short: "Verify an artifact public key was signed by a root key",
|
||||
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
|
||||
This validates the chain of trust from the root key to the artifact key.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
|
||||
return fmt.Errorf("failed to verify artifact key: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(signArtifactCmd)
|
||||
rootCmd.AddCommand(verifyArtifactCmd)
|
||||
rootCmd.AddCommand(verifyArtifactKeyCmd)
|
||||
|
||||
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
|
||||
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
|
||||
|
||||
// artifact-file is required, but artifact-key-file can come from env var
|
||||
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||
}
|
||||
|
||||
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
|
||||
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
|
||||
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
|
||||
|
||||
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||
}
|
||||
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
|
||||
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
|
||||
|
||||
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
|
||||
panic(fmt.Errorf("mark root-key-file as required: %w", err))
|
||||
}
|
||||
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
|
||||
cmd.Println("🖋️ Signing artifact...")
|
||||
|
||||
// Load private key from env var or file
|
||||
var privKeyPEM []byte
|
||||
var err error
|
||||
|
||||
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
|
||||
// Use key from environment variable
|
||||
privKeyPEM = []byte(envKey)
|
||||
} else if privKeyFile != "" {
|
||||
// Fall back to file
|
||||
privKeyPEM, err = os.ReadFile(privKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read private key file: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
|
||||
}
|
||||
|
||||
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse artifact private key: %w", err)
|
||||
}
|
||||
|
||||
artifactData, err := os.ReadFile(artifactFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read artifact file: %w", err)
|
||||
}
|
||||
|
||||
signature, err := reposign.SignData(privateKey, artifactData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sign artifact: %w", err)
|
||||
}
|
||||
|
||||
sigFile := artifactFile + ".sig"
|
||||
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
|
||||
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("✅ Artifact signed successfully.\n")
|
||||
cmd.Printf("Signature file: %s\n", sigFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
|
||||
cmd.Println("🔍 Verifying artifact...")
|
||||
|
||||
// Read artifact public key
|
||||
pubKeyPEM, err := os.ReadFile(pubKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read public key file: %w", err)
|
||||
}
|
||||
|
||||
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse artifact public key: %w", err)
|
||||
}
|
||||
|
||||
// Read artifact data
|
||||
artifactData, err := os.ReadFile(artifactFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read artifact file: %w", err)
|
||||
}
|
||||
|
||||
// Read signature
|
||||
sigBytes, err := os.ReadFile(signatureFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read signature file: %w", err)
|
||||
}
|
||||
|
||||
signature, err := reposign.ParseSignature(sigBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse signature: %w", err)
|
||||
}
|
||||
|
||||
// Validate artifact
|
||||
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
|
||||
return fmt.Errorf("artifact verification failed: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Artifact signature is valid")
|
||||
cmd.Printf("Artifact: %s\n", artifactFile)
|
||||
cmd.Printf("Signed by key: %s\n", signature.KeyID)
|
||||
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
|
||||
cmd.Println("🔍 Verifying artifact key...")
|
||||
|
||||
// Read artifact key data
|
||||
artifactKeyData, err := os.ReadFile(artifactKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read artifact key file: %w", err)
|
||||
}
|
||||
|
||||
// Read root public key(s)
|
||||
rootKeyData, err := os.ReadFile(rootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read root key file: %w", err)
|
||||
}
|
||||
|
||||
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse root public key(s): %w", err)
|
||||
}
|
||||
|
||||
// Read signature
|
||||
sigBytes, err := os.ReadFile(signatureFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read signature file: %w", err)
|
||||
}
|
||||
|
||||
signature, err := reposign.ParseSignature(sigBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse signature: %w", err)
|
||||
}
|
||||
|
||||
// Read optional revocation list
|
||||
var revocationList *reposign.RevocationList
|
||||
if revocationFile != "" {
|
||||
revData, err := os.ReadFile(revocationFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read revocation file: %w", err)
|
||||
}
|
||||
|
||||
revocationList, err = reposign.ParseRevocationList(revData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate artifact key(s)
|
||||
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
|
||||
if err != nil {
|
||||
return fmt.Errorf("artifact key verification failed: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Artifact key(s) verified successfully")
|
||||
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
|
||||
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
|
||||
for i, key := range validKeys {
|
||||
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
|
||||
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
|
||||
if !key.Metadata.ExpiresAt.IsZero() {
|
||||
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
|
||||
} else {
|
||||
cmd.Printf(" Expires: Never\n")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseRootPublicKeys parses a root public key from PEM data
|
||||
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
|
||||
key, err := reposign.ParseRootPublicKey(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []reposign.PublicKey{key}, nil
|
||||
}
|
||||
21
client/cmd/signer/main.go
Normal file
21
client/cmd/signer/main.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "signer",
|
||||
Short: "A CLI tool for managing cryptographic keys and artifacts",
|
||||
Long: `signer is a command-line tool that helps you manage
|
||||
root keys, artifact keys, and revocation lists securely.`,
|
||||
}
|
||||
|
||||
func main() {
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
rootCmd.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
220
client/cmd/signer/revocation.go
Normal file
220
client/cmd/signer/revocation.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
|
||||
)
|
||||
|
||||
var (
|
||||
keyID string
|
||||
revocationListFile string
|
||||
privateRootKeyFile string
|
||||
publicRootKeyFile string
|
||||
signatureFile string
|
||||
expirationDuration time.Duration
|
||||
)
|
||||
|
||||
var createRevocationListCmd = &cobra.Command{
|
||||
Use: "create-revocation-list",
|
||||
Short: "Create a new revocation list signed by the private root key",
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
|
||||
},
|
||||
}
|
||||
|
||||
var extendRevocationListCmd = &cobra.Command{
|
||||
Use: "extend-revocation-list",
|
||||
Short: "Extend an existing revocation list with a given key ID",
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
|
||||
},
|
||||
}
|
||||
|
||||
var verifyRevocationListCmd = &cobra.Command{
|
||||
Use: "verify-revocation-list",
|
||||
Short: "Verify a revocation list signature using the public root key",
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(createRevocationListCmd)
|
||||
rootCmd.AddCommand(extendRevocationListCmd)
|
||||
rootCmd.AddCommand(verifyRevocationListCmd)
|
||||
|
||||
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
|
||||
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
|
||||
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
|
||||
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
|
||||
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
|
||||
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create revocation list: %w", err)
|
||||
}
|
||||
|
||||
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output files: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Revocation list created successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
|
||||
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||
}
|
||||
|
||||
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||
}
|
||||
|
||||
rlBytes, err := os.ReadFile(revocationListFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||
}
|
||||
|
||||
rl, err := reposign.ParseRevocationList(rlBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||
}
|
||||
|
||||
kid, err := reposign.ParseKeyID(keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid key ID: %w", err)
|
||||
}
|
||||
|
||||
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extend revocation list: %w", err)
|
||||
}
|
||||
|
||||
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
|
||||
return fmt.Errorf("failed to write output files: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("✅ Revocation list extended successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
|
||||
// Read revocation list file
|
||||
rlBytes, err := os.ReadFile(revocationListFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||
}
|
||||
|
||||
// Read signature file
|
||||
sigBytes, err := os.ReadFile(signatureFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read signature file: %w", err)
|
||||
}
|
||||
|
||||
// Read public root key file
|
||||
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read public root key file: %w", err)
|
||||
}
|
||||
|
||||
// Parse public root key
|
||||
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public root key: %w", err)
|
||||
}
|
||||
|
||||
// Parse signature
|
||||
signature, err := reposign.ParseSignature(sigBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse signature: %w", err)
|
||||
}
|
||||
|
||||
// Validate revocation list
|
||||
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate revocation list: %w", err)
|
||||
}
|
||||
|
||||
// Display results
|
||||
cmd.Println("✅ Revocation list signature is valid")
|
||||
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
|
||||
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
|
||||
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
|
||||
|
||||
if len(rl.Revoked) > 0 {
|
||||
cmd.Println("\nRevoked Keys:")
|
||||
for keyID, revokedTime := range rl.Revoked {
|
||||
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
|
||||
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
|
||||
return fmt.Errorf("failed to write revocation list file: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
|
||||
return fmt.Errorf("failed to write signature file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
74
client/cmd/signer/rootkey.go
Normal file
74
client/cmd/signer/rootkey.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
privKeyFile string
|
||||
pubKeyFile string
|
||||
rootExpiration time.Duration
|
||||
)
|
||||
|
||||
var createRootKeyCmd = &cobra.Command{
|
||||
Use: "create-root-key",
|
||||
Short: "Create a new root key pair",
|
||||
Long: `Create a new root key pair and specify an expiration time for it.`,
|
||||
SilenceUsage: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Validate expiration
|
||||
if rootExpiration <= 0 {
|
||||
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||
}
|
||||
|
||||
// Run main logic
|
||||
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
|
||||
return fmt.Errorf("failed to generate root key: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(createRootKeyCmd)
|
||||
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
|
||||
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
|
||||
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
|
||||
|
||||
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
|
||||
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate root key: %w", err)
|
||||
}
|
||||
|
||||
// Write private key
|
||||
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
|
||||
}
|
||||
|
||||
// Write public key
|
||||
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
|
||||
}
|
||||
|
||||
cmd.Printf("%s\n\n", rk.String())
|
||||
cmd.Printf("✅ Root key pair generated successfully.\n")
|
||||
return nil
|
||||
}
|
||||
@@ -51,6 +51,7 @@ var (
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
requestPTY bool
|
||||
sshNoBrowser bool
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -81,6 +82,7 @@ func init() {
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
@@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
|
||||
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
if parsed, err := strconv.ParseBool(envValue); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
@@ -196,6 +213,7 @@ func resetSSHGlobals() {
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
sshNoBrowser = false
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
@@ -370,6 +388,7 @@ type sshFlags struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
NoBrowser bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
@@ -381,6 +400,7 @@ type sshFlags struct {
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
@@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
@@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
sshNoBrowser = flags.NoBrowser
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
@@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
NoBrowser: sshNoBrowser,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -749,7 +772,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
|
||||
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
@@ -761,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
// Check env var for browser setting since this command is invoked via SSH ProxyCommand
|
||||
// where command-line flags cannot be passed. Default is to open browser.
|
||||
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
|
||||
var browserOpener func(string) error
|
||||
if !noBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
@@ -788,7 +821,8 @@ var sshDetectCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
if err := util.InitLog(detectLogLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
@@ -797,15 +831,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Debugf("invalid port %q: %v", portStr, err)
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("SSH server detection failed: %v", err)
|
||||
cancel()
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
cancel()
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
@@ -24,8 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil)
|
||||
|
||||
13
client/cmd/update.go
Normal file
13
client/cmd/update.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var updateCmd *cobra.Command
|
||||
|
||||
func isUpdateBinary() bool {
|
||||
return false
|
||||
}
|
||||
75
client/cmd/update_supported.go
Normal file
75
client/cmd/update_supported.go
Normal file
@@ -0,0 +1,75 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
updateCmd = &cobra.Command{
|
||||
Use: "update",
|
||||
Short: "Update the NetBird client application",
|
||||
RunE: updateFunc,
|
||||
}
|
||||
|
||||
tempDirFlag string
|
||||
installerFile string
|
||||
serviceDirFlag string
|
||||
dryRunFlag bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
|
||||
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
|
||||
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
|
||||
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
|
||||
}
|
||||
|
||||
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
|
||||
func isUpdateBinary() bool {
|
||||
// Remove extension for cross-platform compatibility
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
baseName := filepath.Base(execPath)
|
||||
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||
|
||||
return name == installer.UpdaterBinaryNameWithoutExtension()
|
||||
}
|
||||
|
||||
func updateFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupLogToFile(tempDirFlag); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("updater started: %s", serviceDirFlag)
|
||||
updater := installer.NewWithDir(tempDirFlag)
|
||||
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
|
||||
log.Errorf("failed to update application: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupLogToFile(dir string) error {
|
||||
logFile := filepath.Join(dir, installer.LogFile)
|
||||
|
||||
if _, err := os.Stat(logFile); err == nil {
|
||||
if err := os.Remove(logFile); err != nil {
|
||||
log.Errorf("failed to remove existing log file: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return util.InitLog(logLevel, util.LogConsole, logFile)
|
||||
}
|
||||
@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
}
|
||||
|
||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
// TODO: make after-startup backoff err available
|
||||
|
||||
@@ -27,7 +27,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
tableNat = "nat"
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableRaw = "raw"
|
||||
tableSecurity = "security"
|
||||
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
@@ -91,11 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
if errors.Is(err, errFilterTableNotFound) {
|
||||
log.Warnf("table 'filter' not found for forward rules")
|
||||
} else {
|
||||
return nil, fmt.Errorf("load filter table: %w", err)
|
||||
}
|
||||
log.Debugf("ip filter table not found: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
@@ -175,7 +175,7 @@ func (r *router) removeNatPreroutingRules() error {
|
||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
@@ -187,14 +187,39 @@ func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func hookName(hook *nftables.ChainHook) string {
|
||||
if hook == nil {
|
||||
return "unknown"
|
||||
}
|
||||
switch *hook {
|
||||
case *nftables.ChainHookForward:
|
||||
return chainNameForward
|
||||
case *nftables.ChainHookInput:
|
||||
return chainNameInput
|
||||
default:
|
||||
return fmt.Sprintf("hook(%d)", *hook)
|
||||
}
|
||||
}
|
||||
|
||||
func familyName(family nftables.TableFamily) string {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4:
|
||||
return "ip"
|
||||
case nftables.TableFamilyIPv6:
|
||||
return "ip6"
|
||||
case nftables.TableFamilyINet:
|
||||
return "inet"
|
||||
default:
|
||||
return fmt.Sprintf("family(%d)", family)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
@@ -236,9 +261,12 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.addPostroutingRules()
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
@@ -250,11 +278,7 @@ func (r *router) createContainers() error {
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -695,7 +719,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// addPostroutingRules adds the masquerade rules
|
||||
func (r *router) addPostroutingRules() error {
|
||||
func (r *router) addPostroutingRules() {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
@@ -761,8 +785,6 @@ func (r *router) addPostroutingRules() error {
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
@@ -839,7 +861,7 @@ func (r *router) addMSSClampingRules() error {
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return nil
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
@@ -939,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.acceptFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -953,11 +988,11 @@ func (r *router) acceptForwardRules() error {
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
// filter table exists but iptables is not
|
||||
// iptables is not available but the filter table exists
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables()
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
@@ -968,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
@@ -976,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
@@ -996,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesNftables() error {
|
||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||
// This is used when iptables is not available.
|
||||
func (r *router) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
forwardChain := &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertForwardAcceptRules(forwardChain, intf)
|
||||
|
||||
inputChain := &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertInputAcceptRule(inputChain, intf)
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||
func (r *router) acceptExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -1030,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error {
|
||||
Data: intf,
|
||||
},
|
||||
}
|
||||
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
r.conn.InsertRule(oifRule)
|
||||
}
|
||||
|
||||
func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
inputRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
@@ -1067,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error {
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
}
|
||||
r.conn.InsertRule(inputRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *router) removeFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptFilterRulesNftables()
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
if chain.Table.Name != table.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1100,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
continue
|
||||
}
|
||||
|
||||
rules, err := r.conn.GetRules(r.filterTable, chain)
|
||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||
// ensuring cleanup works even after a crash or if chains changed.
|
||||
func (r *router) removeExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||
func (r *router) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules: %v", err)
|
||||
log.Debugf("list chains for family %d: %v", family, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
for _, chain := range allChains {
|
||||
if r.isExternalChain(chain) {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
return chains
|
||||
}
|
||||
|
||||
func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
return nil
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Type != nftables.ChainTypeFilter {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Hooknum == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||
}
|
||||
|
||||
func isIptablesTable(name string) bool {
|
||||
switch name {
|
||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
@@ -1128,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
@@ -1196,7 +1358,7 @@ func (r *router) refreshRulesMap() error {
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(" unable to list rules: %v", err)
|
||||
return fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -19,11 +20,12 @@ import (
|
||||
|
||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||
type WGTunDevice struct {
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu uint16
|
||||
iceBind *bind.ICEBind
|
||||
address wgaddr.Address
|
||||
port int
|
||||
key string
|
||||
mtu uint16
|
||||
iceBind *bind.ICEBind
|
||||
// todo: review if we can eliminate the TunAdapter
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
@@ -32,17 +34,19 @@ type WGTunDevice struct {
|
||||
filteredDevice *FilteredDevice
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
configurer WGConfigurer
|
||||
renewableTun *RenewableTUN
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu uint16, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
address: address,
|
||||
port: port,
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
renewableTun: NewRenewableTUN(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,14 +69,17 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tunDevice, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
unmonitoredTUN, name, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||
|
||||
t.name = name
|
||||
t.filteredDevice = newDeviceFilter(tunDevice)
|
||||
t.filteredDevice = newDeviceFilter(t.renewableTun)
|
||||
|
||||
log.Debugf("attaching to interface %v", name)
|
||||
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "))
|
||||
@@ -104,6 +111,23 @@ func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) RenewTun(fd int) error {
|
||||
if t.device == nil {
|
||||
return fmt.Errorf("device not initialized")
|
||||
}
|
||||
|
||||
unmonitoredTUN, _, err := tun.CreateUnmonitoredTUNFromFD(fd)
|
||||
if err != nil {
|
||||
_ = unix.Close(fd)
|
||||
log.Errorf("failed to renew Android interface: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
t.renewableTun.AddDevice(unmonitoredTUN)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||
// todo implement
|
||||
return nil
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
|
||||
package device
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (t *TunNetstackDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
||||
return t.create()
|
||||
}
|
||||
|
||||
func (t *TunNetstackDevice) RenewTun(fd int) error {
|
||||
// Doesn't make sense in Android for Netstack.
|
||||
return fmt.Errorf("this function has not been implemented in Netstack for Android")
|
||||
}
|
||||
|
||||
309
client/iface/device/renewable_tun.go
Normal file
309
client/iface/device/renewable_tun.go
Normal file
@@ -0,0 +1,309 @@
|
||||
//go:build android
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// closeAwareDevice wraps a tun.Device along with a flag
|
||||
// indicating whether its Close method was called.
|
||||
//
|
||||
// It also redirects tun.Device's Events() to a separate goroutine
|
||||
// and closes it when Close is called.
|
||||
//
|
||||
// The WaitGroup and CloseOnce fields are used to ensure that the
|
||||
// goroutine is awaited and closed only once.
|
||||
type closeAwareDevice struct {
|
||||
isClosed atomic.Bool
|
||||
tun.Device
|
||||
closeEventCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newClosableDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||
return &closeAwareDevice{
|
||||
Device: tunDevice,
|
||||
isClosed: atomic.Bool{},
|
||||
closeEventCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||
// to the given channel (RenewableTUN's events channel).
|
||||
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-c.Device.Events():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if ev == tun.EventDown {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case out <- ev:
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Close calls the underlying Device's Close method
|
||||
// after setting isClosed to true.
|
||||
func (c *closeAwareDevice) Close() (err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
c.isClosed.Store(true)
|
||||
close(c.closeEventCh)
|
||||
err = c.Device.Close()
|
||||
c.wg.Wait()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *closeAwareDevice) IsClosed() bool {
|
||||
return c.isClosed.Load()
|
||||
}
|
||||
|
||||
type RenewableTUN struct {
|
||||
devices []*closeAwareDevice
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
events chan tun.Event
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewRenewableTUN() *RenewableTUN {
|
||||
r := &RenewableTUN{
|
||||
devices: make([]*closeAwareDevice, 0),
|
||||
mu: sync.Mutex{},
|
||||
events: make(chan tun.Event, 16),
|
||||
}
|
||||
r.cond = sync.NewCond(&r.mu)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) File() *os.File {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
file := dev.File()
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from an underlying tun.Device kept in the r.devices slice.
|
||||
// If no device is available, it waits for one to be added via AddDevice().
|
||||
//
|
||||
// On error, it retries reading from the newest device instead of returning the error
|
||||
// if the device is closed; if not, it propagates the error.
|
||||
func (r *RenewableTUN) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
// wait until AddDevice() signals a new device via cond.Broadcast()
|
||||
if !r.waitForDevice() { // returns false if the renewable TUN itself is closed
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err = dev.Read(bufs, sizes, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// swap in progress; retry on the newest instead of returning the error
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return n, err // propagate non-swap error
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes to underlying tun.Device kept in the r.devices slice.
|
||||
// If no device is available, it waits for one to be added via AddDevice().
|
||||
//
|
||||
// On error, it retries writing to the newest device instead of returning the error
|
||||
// if the device is closed; if not, it propagates the error.
|
||||
func (r *RenewableTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := dev.Write(bufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) MTU() (int, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
mtu, err := dev.MTU()
|
||||
if err == nil {
|
||||
return mtu, nil
|
||||
}
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) Name() (string, error) {
|
||||
for {
|
||||
dev := r.peekLast()
|
||||
if dev == nil {
|
||||
if !r.waitForDevice() {
|
||||
return "", io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
name, err := dev.Name()
|
||||
if err == nil {
|
||||
return name, nil
|
||||
}
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Events returns a channel that is fed events from the underlying tun.Device's events channel
|
||||
// once it is added.
|
||||
func (r *RenewableTUN) Events() <-chan tun.Event {
|
||||
return r.events
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) Close() error {
|
||||
// Attempts to set the RenewableTUN closed flag to true.
|
||||
// If it's already true, returns immediately.
|
||||
if !r.closed.CompareAndSwap(false, true) {
|
||||
return nil // already closed: idempotent
|
||||
}
|
||||
r.mu.Lock()
|
||||
devices := r.devices
|
||||
r.devices = nil
|
||||
r.cond.Broadcast()
|
||||
r.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
|
||||
log.Debugf("closing %d devices", len(devices))
|
||||
for _, device := range devices {
|
||||
if err := device.Close(); err != nil {
|
||||
log.Debugf("error closing a device: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
close(r.events)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) AddDevice(device tun.Device) {
|
||||
r.mu.Lock()
|
||||
if r.closed.Load() {
|
||||
r.mu.Unlock()
|
||||
_ = device.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var toClose *closeAwareDevice
|
||||
if len(r.devices) > 0 {
|
||||
toClose = r.devices[len(r.devices)-1]
|
||||
}
|
||||
|
||||
cad := newClosableDevice(device)
|
||||
cad.redirectEvents(r.events)
|
||||
|
||||
r.devices = []*closeAwareDevice{cad}
|
||||
r.cond.Broadcast()
|
||||
|
||||
r.mu.Unlock()
|
||||
|
||||
if toClose != nil {
|
||||
if err := toClose.Close(); err != nil {
|
||||
log.Debugf("error closing last device: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) waitForDevice() bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for len(r.devices) == 0 && !r.closed.Load() {
|
||||
r.cond.Wait()
|
||||
}
|
||||
return !r.closed.Load()
|
||||
}
|
||||
|
||||
func (r *RenewableTUN) peekLast() *closeAwareDevice {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if len(r.devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.devices[len(r.devices)-1]
|
||||
}
|
||||
@@ -21,5 +21,6 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
RenewTun(fd int) error
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -24,3 +24,7 @@ func (w *WGIface) Create() error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on non mobile")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
return fmt.Errorf("this function has not been implemented on non-android")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||
// Will reuse an existing one.
|
||||
// todo: review does this function really necessary or can we merge it with iOS
|
||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
@@ -22,3 +23,9 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
|
||||
func (w *WGIface) Create() error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.tun.RenewTun(fd)
|
||||
}
|
||||
|
||||
@@ -39,3 +39,7 @@ func (w *WGIface) Create() error {
|
||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
||||
return fmt.Errorf("this function has not implemented on this platform")
|
||||
}
|
||||
|
||||
func (w *WGIface) RenewTun(fd int) error {
|
||||
return fmt.Errorf("this function has not been implemented on this platform")
|
||||
}
|
||||
|
||||
@@ -60,14 +60,19 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
func shouldUseDeviceFlow(force bool, isUnixDesktopClient bool) bool {
|
||||
return force || (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||
//
|
||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
// forceDeviceCodeFlow can be used to skip PKCE and go directly to Device Code Flow (e.g., for Android TV)
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, forceDeviceCodeFlow bool, hint string) (OAuthFlow, error) {
|
||||
if shouldUseDeviceFlow(forceDeviceCodeFlow, isUnixDesktopClient) {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/templates"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
)
|
||||
|
||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||
@@ -46,9 +48,10 @@ type PKCEAuthorizationFlow struct {
|
||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||
var availableRedirectURL string
|
||||
|
||||
// find the first available redirect URL
|
||||
excludedRanges := getSystemExcludedPortRanges()
|
||||
|
||||
for _, redirectURL := range config.RedirectURLs {
|
||||
if !isRedirectURLPortUsed(redirectURL) {
|
||||
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
||||
availableRedirectURL = redirectURL
|
||||
break
|
||||
}
|
||||
@@ -102,10 +105,10 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||
}
|
||||
if !p.providerConfig.DisablePromptLogin {
|
||||
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||
switch p.providerConfig.LoginFlag {
|
||||
case common.LoginFlagPromptLogin:
|
||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||
}
|
||||
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||
case common.LoginFlagMaxAge0:
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
@@ -282,15 +285,22 @@ func createCodeChallenge(codeVerifier string) string {
|
||||
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
||||
}
|
||||
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use.
|
||||
func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
||||
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
||||
parsedURL, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse redirect URL: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", parsedURL.Port())
|
||||
port := parsedURL.Port()
|
||||
|
||||
if isPortInExcludedRange(port, excludedRanges) {
|
||||
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", port)
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
@@ -304,6 +314,33 @@ func isRedirectURLPortUsed(redirectURL string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// excludedPortRange represents a range of excluded ports.
|
||||
type excludedPortRange struct {
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
||||
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
||||
if len(excludedRanges) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
log.Debugf("invalid port number %s: %v", port, err)
|
||||
return false
|
||||
}
|
||||
|
||||
for _, r := range excludedRanges {
|
||||
if portNum >= r.start && portNum <= r.end {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
||||
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
||||
if err != nil {
|
||||
|
||||
8
client/internal/auth/pkce_flow_other.go
Normal file
8
client/internal/auth/pkce_flow_other.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package auth
|
||||
|
||||
// getSystemExcludedPortRanges returns nil on non-Windows platforms.
|
||||
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||
return nil
|
||||
}
|
||||
@@ -2,8 +2,11 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -20,22 +23,28 @@ func TestPromptLogin(t *testing.T) {
|
||||
name string
|
||||
loginFlag mgm.LoginFlag
|
||||
disablePromptLogin bool
|
||||
expect string
|
||||
expectContains []string
|
||||
}{
|
||||
{
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
expect: promptLogin,
|
||||
name: "Prompt login",
|
||||
loginFlag: mgm.LoginFlagPromptLogin,
|
||||
expectContains: []string{promptLogin},
|
||||
},
|
||||
{
|
||||
name: "Max age 0 login",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expect: maxAge0,
|
||||
name: "Max age 0",
|
||||
loginFlag: mgm.LoginFlagMaxAge0,
|
||||
expectContains: []string{maxAge0},
|
||||
},
|
||||
{
|
||||
name: "Disable prompt login",
|
||||
loginFlag: mgm.LoginFlagPrompt,
|
||||
loginFlag: mgm.LoginFlagPromptLogin,
|
||||
disablePromptLogin: true,
|
||||
expectContains: []string{},
|
||||
},
|
||||
{
|
||||
name: "None flag should not add parameters",
|
||||
loginFlag: mgm.LoginFlagNone,
|
||||
expectContains: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -50,6 +59,7 @@ func TestPromptLogin(t *testing.T) {
|
||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||
UseIDToken: true,
|
||||
LoginFlag: tc.loginFlag,
|
||||
DisablePromptLogin: tc.disablePromptLogin,
|
||||
}
|
||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||
if err != nil {
|
||||
@@ -60,12 +70,153 @@ func TestPromptLogin(t *testing.T) {
|
||||
t.Fatalf("Failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
if !tc.disablePromptLogin {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||
} else {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||
for _, expected := range tc.expectContains {
|
||||
require.Contains(t, authInfo.VerificationURIComplete, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsPortInExcludedRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
port string
|
||||
excludedRanges []excludedPortRange
|
||||
expectedBlocked bool
|
||||
}{
|
||||
{
|
||||
name: "Port in excluded range",
|
||||
port: "8080",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port at start of range",
|
||||
port: "8000",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port at end of range",
|
||||
port: "8100",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Port before range",
|
||||
port: "7999",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Port after range",
|
||||
port: "8101",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded ranges",
|
||||
port: "8080",
|
||||
excludedRanges: []excludedPortRange{},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Nil excluded ranges",
|
||||
port: "8080",
|
||||
excludedRanges: nil,
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - port in second range",
|
||||
port: "9050",
|
||||
excludedRanges: []excludedPortRange{
|
||||
{start: 8000, end: 8100},
|
||||
{start: 9000, end: 9100},
|
||||
},
|
||||
expectedBlocked: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple ranges - port not in any range",
|
||||
port: "8500",
|
||||
excludedRanges: []excludedPortRange{
|
||||
{start: 8000, end: 8100},
|
||||
{start: 9000, end: 9100},
|
||||
},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid port string",
|
||||
port: "invalid",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "Empty port string",
|
||||
port: "",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedBlocked: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isPortInExcludedRange(tt.port, tt.excludedRanges)
|
||||
assert.Equal(t, tt.expectedBlocked, result, "Port exclusion check mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRedirectURLPortUsed(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
usedPort := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURL string
|
||||
excludedRanges []excludedPortRange
|
||||
expectedUsed bool
|
||||
}{
|
||||
{
|
||||
name: "Port in excluded range",
|
||||
redirectURL: "http://127.0.0.1:8080/",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedUsed: true,
|
||||
},
|
||||
{
|
||||
name: "Port actually in use",
|
||||
redirectURL: fmt.Sprintf("http://127.0.0.1:%d/", usedPort),
|
||||
excludedRanges: nil,
|
||||
expectedUsed: true,
|
||||
},
|
||||
{
|
||||
name: "Port not in use and not excluded",
|
||||
redirectURL: "http://127.0.0.1:65432/",
|
||||
excludedRanges: nil,
|
||||
expectedUsed: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid URL without port",
|
||||
redirectURL: "not-a-valid-url",
|
||||
excludedRanges: nil,
|
||||
expectedUsed: false,
|
||||
},
|
||||
{
|
||||
name: "Port excluded even if not in use",
|
||||
redirectURL: "http://127.0.0.1:8050/",
|
||||
excludedRanges: []excludedPortRange{{start: 8000, end: 8100}},
|
||||
expectedUsed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isRedirectURLPortUsed(tt.redirectURL, tt.excludedRanges)
|
||||
assert.Equal(t, tt.expectedUsed, result, "Port usage check mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
86
client/internal/auth/pkce_flow_windows.go
Normal file
86
client/internal/auth/pkce_flow_windows.go
Normal file
@@ -0,0 +1,86 @@
|
||||
//go:build windows
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// getSystemExcludedPortRanges retrieves the excluded port ranges from Windows using netsh.
|
||||
func getSystemExcludedPortRanges() []excludedPortRange {
|
||||
ranges, err := getExcludedPortRangesFromNetsh()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get Windows excluded port ranges: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ranges
|
||||
}
|
||||
|
||||
// getExcludedPortRangesFromNetsh retrieves excluded port ranges using netsh command.
|
||||
func getExcludedPortRangesFromNetsh() ([]excludedPortRange, error) {
|
||||
cmd := exec.Command("netsh", "interface", "ipv4", "show", "excludedportrange", "protocol=tcp")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("netsh command: %w", err)
|
||||
}
|
||||
|
||||
return parseExcludedPortRanges(string(output))
|
||||
}
|
||||
|
||||
// parseExcludedPortRanges parses the output of the netsh command to extract port ranges.
|
||||
func parseExcludedPortRanges(output string) ([]excludedPortRange, error) {
|
||||
var ranges []excludedPortRange
|
||||
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||
|
||||
foundHeader := false
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if strings.Contains(line, "Start Port") && strings.Contains(line, "End Port") {
|
||||
foundHeader = true
|
||||
continue
|
||||
}
|
||||
|
||||
if !foundHeader {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(line, "----------") {
|
||||
continue
|
||||
}
|
||||
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
startPort, err := strconv.Atoi(fields[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
endPort, err := strconv.Atoi(fields[1])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ranges = append(ranges, excludedPortRange{start: startPort, end: endPort})
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan output: %w", err)
|
||||
}
|
||||
|
||||
return ranges, nil
|
||||
}
|
||||
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
116
client/internal/auth/pkce_flow_windows_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
//go:build windows
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
func TestParseExcludedPortRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
netshOutput string
|
||||
expectedRanges []excludedPortRange
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid netsh output with multiple ranges",
|
||||
netshOutput: `
|
||||
Protocol tcp Dynamic Port Range
|
||||
---------------------------------
|
||||
Start Port : 49152
|
||||
Number of Ports : 16384
|
||||
|
||||
Protocol tcp Excluded Port Ranges
|
||||
---------------------------------
|
||||
Start Port End Port
|
||||
---------- --------
|
||||
5357 5357 *
|
||||
50000 50059 *
|
||||
`,
|
||||
expectedRanges: []excludedPortRange{
|
||||
{start: 5357, end: 5357},
|
||||
{start: 50000, end: 50059},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty output",
|
||||
netshOutput: `
|
||||
Protocol tcp Dynamic Port Range
|
||||
---------------------------------
|
||||
Start Port : 49152
|
||||
Number of Ports : 16384
|
||||
`,
|
||||
expectedRanges: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
netshOutput: `
|
||||
Protocol tcp Excluded Port Ranges
|
||||
---------------------------------
|
||||
Start Port End Port
|
||||
---------- --------
|
||||
8080 8090
|
||||
`,
|
||||
expectedRanges: []excludedPortRange{
|
||||
{start: 8080, end: 8090},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ranges, err := parseExcludedPortRanges(tt.netshOutput)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRanges, ranges)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
||||
ranges := getSystemExcludedPortRanges()
|
||||
t.Logf("Found %d excluded port ranges on this system", len(ranges))
|
||||
|
||||
listener1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = listener1.Close()
|
||||
}()
|
||||
usedPort1 := listener1.Addr().(*net.TCPAddr).Port
|
||||
|
||||
availablePort := 65432
|
||||
|
||||
config := internal.PKCEAuthProviderConfig{
|
||||
ClientID: "test-client-id",
|
||||
Audience: "test-audience",
|
||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||
Scope: "openid email profile",
|
||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||
RedirectURLs: []string{
|
||||
fmt.Sprintf("http://127.0.0.1:%d/", usedPort1),
|
||||
fmt.Sprintf("http://127.0.0.1:%d/", availablePort),
|
||||
},
|
||||
UseIDToken: true,
|
||||
}
|
||||
|
||||
flow, err := NewPKCEAuthorizationFlow(config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, flow)
|
||||
assert.Contains(t, flow.oAuthConfig.RedirectURL, fmt.Sprintf(":%d", availablePort),
|
||||
"Should skip port in use and select available port")
|
||||
}
|
||||
@@ -24,10 +24,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -39,11 +43,13 @@ import (
|
||||
)
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
doInitialAutoUpdate bool
|
||||
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
|
||||
persistSyncResponse bool
|
||||
}
|
||||
@@ -52,13 +58,15 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +82,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
@@ -82,6 +91,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil)
|
||||
}
|
||||
@@ -160,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return err
|
||||
}
|
||||
|
||||
var path string
|
||||
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||
// On mobile, use the provided state file path directly
|
||||
if !fileExists(mobileDependency.StateFilePath) {
|
||||
if err := createFile(mobileDependency.StateFilePath); err != nil {
|
||||
log.Errorf("failed to create state file: %v", err)
|
||||
// we are not exiting as we can run without the state manager
|
||||
}
|
||||
}
|
||||
path = mobileDependency.StateFilePath
|
||||
} else {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
path = sm.GetStatePath()
|
||||
}
|
||||
stateManager := statemanager.New(path)
|
||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||
if err == nil {
|
||||
updateManager.CheckUpdateSuccess(c.ctx)
|
||||
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
operation := func() error {
|
||||
// if context cancelled we not start new backoff cycle
|
||||
@@ -271,15 +308,25 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
||||
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
if err := engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||
if c.doInitialAutoUpdate {
|
||||
log.Infof("start engine by ui, run auto-update check")
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
c.doInitialAutoUpdate = false
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
@@ -291,12 +338,14 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
// todo: consider to remove this condition. Is not thread safe.
|
||||
// We should always call Stop(), but we need to verify that it is idempotent
|
||||
if engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -56,6 +57,7 @@ block.prof: Block profiling information.
|
||||
heap.prof: Heap profiling information (snapshot of memory allocations).
|
||||
allocs.prof: Allocations profiling information.
|
||||
threadcreate.prof: Thread creation profiling information.
|
||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||
|
||||
|
||||
Anonymization Process
|
||||
@@ -109,6 +111,9 @@ go tool pprof -http=:8088 heap.prof
|
||||
|
||||
This will open a web browser tab with the profiling information.
|
||||
|
||||
Stack Trace
|
||||
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
|
||||
|
||||
Routes
|
||||
The routes.txt file contains detailed routing table information in a tabular format:
|
||||
|
||||
@@ -327,6 +332,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add profiles to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addStackTrace(); err != nil {
|
||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addSyncResponse(); err != nil {
|
||||
return fmt.Errorf("add sync response: %w", err)
|
||||
}
|
||||
@@ -354,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add systemd logs: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addUpdateLogs(); err != nil {
|
||||
log.Errorf("failed to add updater logs: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -522,6 +535,18 @@ func (g *BundleGenerator) addProf() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addStackTrace() error {
|
||||
buf := make([]byte, 5242880) // 5 MB buffer
|
||||
n := runtime.Stack(buf, true)
|
||||
|
||||
stackTrace := bytes.NewReader(buf[:n])
|
||||
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
|
||||
return fmt.Errorf("add stack trace file to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addInterfaces() error {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
@@ -630,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addUpdateLogs() error {
|
||||
inst := installer.New()
|
||||
logFiles := inst.LogFiles()
|
||||
if len(logFiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("adding updater logs")
|
||||
for _, logFile := range logFiles {
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
log.Warnf("failed to read update log file %s: %v", logFile, err)
|
||||
continue
|
||||
}
|
||||
|
||||
baseName := filepath.Base(logFile)
|
||||
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
|
||||
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
pattern := sm.GetStatePath()
|
||||
|
||||
@@ -76,6 +76,9 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
|
||||
var records []nbdns.SimpleRecord
|
||||
|
||||
for _, zone := range config.CustomZones {
|
||||
if zone.SkipPTRProcess {
|
||||
continue
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
if record.Type != int(dns.TypeA) {
|
||||
continue
|
||||
@@ -106,8 +109,9 @@ func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||
records := collectPTRRecords(config, network)
|
||||
|
||||
reverseZone := nbdns.CustomZone{
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
Domain: zoneName,
|
||||
Records: records,
|
||||
SearchDomainDisabled: true,
|
||||
}
|
||||
|
||||
config.CustomZones = append(config.CustomZones, reverseZone)
|
||||
|
||||
@@ -11,11 +11,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4ReverseZone = ".in-addr.arpa."
|
||||
ipv6ReverseZone = ".ip6.arpa."
|
||||
)
|
||||
|
||||
type hostManager interface {
|
||||
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
|
||||
restoreHostDNS() error
|
||||
@@ -110,10 +105,9 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip netip.Addr, port int) H
|
||||
}
|
||||
|
||||
for _, customZone := range dnsConfig.CustomZones {
|
||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||
config.Domains = append(config.Domains, DomainConfig{
|
||||
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||
MatchOnly: matchOnly,
|
||||
MatchOnly: customZone.SearchDomainDisabled,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -26,6 +27,11 @@ type Resolver struct {
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type ipsResponse struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
|
||||
go func() {
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
resultChan <- &ipsResponse{
|
||||
err: err,
|
||||
ips: ips,
|
||||
}
|
||||
}()
|
||||
|
||||
var resp *ipsResponse
|
||||
|
||||
select {
|
||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp = <-resultChan:
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||
}
|
||||
return resp.ips, nil
|
||||
}
|
||||
|
||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
||||
if mgmtURL == nil {
|
||||
|
||||
@@ -80,6 +80,7 @@ type DefaultServer struct {
|
||||
updateSerial uint64
|
||||
previousConfigHash uint64
|
||||
currentConfig HostDNSConfig
|
||||
currentConfigHash uint64
|
||||
handlerChain *HandlerChain
|
||||
extraDomains map[domain.Domain]int
|
||||
|
||||
@@ -207,6 +208,7 @@ func newDefaultServer(
|
||||
hostsDNSHolder: newHostsDNSHolder(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
mgmtCacheResolver: mgmtCacheResolver,
|
||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||
}
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
|
||||
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
|
||||
|
||||
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
|
||||
// Fall through to apply config anyway (fail-safe approach)
|
||||
} else if s.currentConfigHash == hash {
|
||||
log.Debugf("not applying host config as there are no changes")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("applying host config as there are changes")
|
||||
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Only update hash if it was computed successfully and config was applied
|
||||
if err == nil {
|
||||
s.currentConfigHash = hash
|
||||
}
|
||||
|
||||
s.registerFallback(config)
|
||||
|
||||
@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
|
||||
"other.example.com.",
|
||||
"duplicate.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 4,
|
||||
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
|
||||
// the domain remains in the config (ref count goes from 2 to 1), so the host
|
||||
// config hash doesn't change and applyDNSConfig is not called.
|
||||
applyHostConfigCall: 3,
|
||||
},
|
||||
{
|
||||
name: "Config update with new domains after registration",
|
||||
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
|
||||
expectedMatchOnly: []string{
|
||||
"extra.example.com.",
|
||||
},
|
||||
applyHostConfigCall: 3,
|
||||
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
|
||||
// it's removed from extraDomains but still remains in the config (from customZones),
|
||||
// so the host config hash doesn't change and applyDNSConfig is not called.
|
||||
applyHostConfigCall: 2,
|
||||
},
|
||||
{
|
||||
name: "Register domain that is part of nameserver group",
|
||||
|
||||
@@ -197,7 +197,7 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
timeoutMsg += " " + peerInfo
|
||||
}
|
||||
timeoutMsg += fmt.Sprintf(" - error: %v", err)
|
||||
logger.Warnf(timeoutMsg)
|
||||
logger.Warn(timeoutMsg)
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||
|
||||
@@ -234,6 +234,11 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||
for i, ip := range ips {
|
||||
ips[i] = ip.Unmap()
|
||||
}
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
@@ -42,14 +42,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
@@ -73,6 +72,7 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -201,6 +201,9 @@ type Engine struct {
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updatemanager.Manager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
@@ -221,17 +224,7 @@ type localIpUpdater interface {
|
||||
}
|
||||
|
||||
// NewEngine creates a new Connection Engine with probes attached
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
) *Engine {
|
||||
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
@@ -247,28 +240,12 @@ func NewEngine(
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
|
||||
path := sm.GetStatePath()
|
||||
if runtime.GOOS == "ios" {
|
||||
if !fileExists(mobileDep.StateFilePath) {
|
||||
err := createFile(mobileDep.StateFilePath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create state file: %v", err)
|
||||
// we are not exiting as we can run without the state manager
|
||||
}
|
||||
}
|
||||
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
}
|
||||
@@ -280,7 +257,6 @@ func (e *Engine) Stop() error {
|
||||
return nil
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
if e.connMgr != nil {
|
||||
e.connMgr.Close()
|
||||
@@ -298,9 +274,6 @@ func (e *Engine) Stop() error {
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||
@@ -308,24 +281,33 @@ func (e *Engine) Stop() error {
|
||||
e.ingressGatewayMgr = nil
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.Stop()
|
||||
}
|
||||
|
||||
log.Info("cleaning up status recorder states")
|
||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
|
||||
|
||||
if err := e.removeAllPeers(); err != nil {
|
||||
return fmt.Errorf("failed to remove all peers: %s", err)
|
||||
log.Errorf("failed to remove all peers: %s", err)
|
||||
}
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
// stop/restore DNS after peers are closed but before interface goes down
|
||||
// so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
@@ -337,16 +319,18 @@ func (e *Engine) Stop() error {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
stateCtx, stateCancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer stateCancel()
|
||||
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
if err := e.stateManager.Stop(stateCtx); err != nil {
|
||||
log.Errorf("failed to stop state manager: %v", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
@@ -432,8 +416,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create rosenpass manager: %w", err)
|
||||
}
|
||||
err := e.rpManager.Run()
|
||||
if err != nil {
|
||||
if err := e.rpManager.Run(); err != nil {
|
||||
return fmt.Errorf("run rosenpass manager: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -485,6 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -538,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||
}
|
||||
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
@@ -746,10 +737,54 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||
if autoUpdateSettings == nil {
|
||||
return
|
||||
}
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
e.updateManager = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||
return
|
||||
}
|
||||
|
||||
// Start manager if needed
|
||||
if e.updateManager == nil {
|
||||
log.Infof("starting auto-update manager")
|
||||
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e.updateManager = updateManager
|
||||
e.updateManager.Start(e.ctx)
|
||||
}
|
||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
@@ -1207,7 +1242,9 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
dnsZone := nbdns.CustomZone{
|
||||
Domain: zone.GetDomain(),
|
||||
Domain: zone.GetDomain(),
|
||||
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
|
||||
SkipPTRProcess: zone.GetSkipPTRProcess(),
|
||||
}
|
||||
for _, record := range zone.Records {
|
||||
dnsRecord := nbdns.SimpleRecord{
|
||||
@@ -1367,6 +1404,11 @@ func (e *Engine) receiveSignalEvents() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
conn, ok := e.peerStore.PeerConn(msg.Key)
|
||||
if !ok {
|
||||
return fmt.Errorf("wrongly addressed message %s", msg.Key)
|
||||
@@ -1831,6 +1873,18 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
return e.wgInterface.Address().IP
|
||||
}
|
||||
|
||||
func (e *Engine) RenewTun(fd int) error {
|
||||
e.syncMsgMux.Lock()
|
||||
wgInterface := e.wgInterface
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if wgInterface == nil {
|
||||
return fmt.Errorf("wireguard interface not initialized")
|
||||
}
|
||||
|
||||
return wgInterface.RenewTun(fd)
|
||||
}
|
||||
|
||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
|
||||
@@ -30,11 +30,12 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
@@ -54,7 +55,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -110,6 +110,10 @@ type MockWGIface struct {
|
||||
LastActivitiesFunc func() map[string]monotime.Time
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RenewTun(_ int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -249,6 +253,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -410,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
@@ -643,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -808,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
@@ -1010,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
@@ -1536,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
@@ -1624,14 +1621,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
type wgIfaceBase interface {
|
||||
Create() error
|
||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||
RenewTun(fd int) error
|
||||
IsUserspaceBind() bool
|
||||
Name() string
|
||||
Address() wgaddr.Address
|
||||
|
||||
@@ -20,7 +20,7 @@ type EndpointUpdater struct {
|
||||
wgConfig WgConfig
|
||||
initiator bool
|
||||
|
||||
// mu protects updateWireGuardPeer and cancelFunc
|
||||
// mu protects cancelFunc
|
||||
mu sync.Mutex
|
||||
cancelFunc func()
|
||||
updateWg sync.WaitGroup
|
||||
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
e.mu.Lock()
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
@@ -165,19 +166,26 @@ func getConfigDir() (string, error) {
|
||||
if ConfigDirOverride != "" {
|
||||
return ConfigDirOverride, nil
|
||||
}
|
||||
configDir, err := os.UserConfigDir()
|
||||
|
||||
base, err := baseConfigDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
configDir = filepath.Join(configDir, "netbird")
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return "", err
|
||||
configDir := filepath.Join(base, "netbird")
|
||||
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
func baseConfigDir() (string, error) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
if u, err := user.Current(); err == nil && u.HomeDir != "" {
|
||||
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
|
||||
}
|
||||
}
|
||||
|
||||
return configDir, nil
|
||||
return os.UserConfigDir()
|
||||
}
|
||||
|
||||
func getConfigDirForUser(username string) (string, error) {
|
||||
|
||||
@@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) {
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
profilesDir string // If set, overrides ConfigDirOverride for profile operations
|
||||
}
|
||||
|
||||
func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
||||
@@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
||||
return &ServiceManager{}
|
||||
}
|
||||
|
||||
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
|
||||
// This allows setting the profiles directory without modifying the global ConfigDirOverride
|
||||
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
|
||||
if defaultConfigPath != "" {
|
||||
DefaultConfigPath = defaultConfigPath
|
||||
}
|
||||
return &ServiceManager{
|
||||
profilesDir: profilesDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
|
||||
|
||||
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
|
||||
@@ -240,7 +252,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
configDir, err := getConfigDirForUser(username)
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
@@ -270,7 +282,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
configDir, err := getConfigDirForUser(username)
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
@@ -302,7 +314,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
||||
configDir, err := getConfigDirForUser(username)
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
@@ -361,7 +373,7 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
configDir, err := getConfigDirForUser(activeProf.Username)
|
||||
configDir, err := s.getConfigDir(activeProf.Username)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
|
||||
return defaultStatePath
|
||||
@@ -369,3 +381,12 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
|
||||
return filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
}
|
||||
|
||||
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
|
||||
func (s *ServiceManager) getConfigDir(username string) (string, error) {
|
||||
if s.profilesDir != "" {
|
||||
return s.profilesDir, nil
|
||||
}
|
||||
|
||||
return getConfigDirForUser(username)
|
||||
}
|
||||
|
||||
218
client/internal/sleep/detector_darwin.go
Normal file
218
client/internal/sleep/detector_darwin.go
Normal file
@@ -0,0 +1,218 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package sleep
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||
#include <IOKit/IOMessage.h>
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
|
||||
extern void sleepCallbackBridge();
|
||||
extern void poweredOnCallbackBridge();
|
||||
extern void suspendedCallbackBridge();
|
||||
extern void resumedCallbackBridge();
|
||||
|
||||
|
||||
// C global variables for IOKit state
|
||||
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||
static io_object_t g_notifierObject = 0;
|
||||
static io_object_t g_generalInterestNotifier = 0;
|
||||
static io_connect_t g_rootPort = 0;
|
||||
static CFRunLoopRef g_runLoop = NULL;
|
||||
|
||||
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||
switch (messageType) {
|
||||
case kIOMessageSystemWillSleep:
|
||||
sleepCallbackBridge();
|
||||
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||
break;
|
||||
case kIOMessageSystemHasPoweredOn:
|
||||
poweredOnCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsSuspended:
|
||||
suspendedCallbackBridge();
|
||||
break;
|
||||
case kIOMessageServiceIsResumed:
|
||||
resumedCallbackBridge();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void registerNotifications() {
|
||||
g_rootPort = IORegisterForSystemPower(
|
||||
NULL,
|
||||
&g_notifyPortRef,
|
||||
(IOServiceInterestCallback)sleepCallback,
|
||||
&g_notifierObject
|
||||
);
|
||||
|
||||
if (g_rootPort == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
g_runLoop = CFRunLoopGetCurrent();
|
||||
CFRunLoopRun();
|
||||
}
|
||||
|
||||
static void unregisterNotifications() {
|
||||
CFRunLoopRemoveSource(g_runLoop,
|
||||
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||
kCFRunLoopCommonModes);
|
||||
|
||||
IODeregisterForSystemPower(&g_notifierObject);
|
||||
IOServiceClose(g_rootPort);
|
||||
IONotificationPortDestroy(g_notifyPortRef);
|
||||
CFRunLoopStop(g_runLoop);
|
||||
|
||||
g_notifyPortRef = NULL;
|
||||
g_notifierObject = 0;
|
||||
g_rootPort = 0;
|
||||
g_runLoop = NULL;
|
||||
}
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
serviceRegistry = make(map[*Detector]struct{})
|
||||
serviceRegistryMu sync.Mutex
|
||||
)
|
||||
|
||||
//export sleepCallbackBridge
|
||||
func sleepCallbackBridge() {
|
||||
log.Info("sleepCallbackBridge event triggered")
|
||||
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeSleep)
|
||||
}
|
||||
}
|
||||
|
||||
//export resumedCallbackBridge
|
||||
func resumedCallbackBridge() {
|
||||
log.Info("resumedCallbackBridge event triggered")
|
||||
}
|
||||
|
||||
//export suspendedCallbackBridge
|
||||
func suspendedCallbackBridge() {
|
||||
log.Info("suspendedCallbackBridge event triggered")
|
||||
}
|
||||
|
||||
//export poweredOnCallbackBridge
|
||||
func poweredOnCallbackBridge() {
|
||||
log.Info("poweredOnCallbackBridge event triggered")
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
for svc := range serviceRegistry {
|
||||
svc.triggerCallback(EventTypeWakeUp)
|
||||
}
|
||||
}
|
||||
|
||||
type Detector struct {
|
||||
callback func(event EventType)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewDetector() (*Detector, error) {
|
||||
return &Detector{}, nil
|
||||
}
|
||||
|
||||
func (d *Detector) Register(callback func(event EventType)) error {
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
|
||||
if _, exists := serviceRegistry[d]; exists {
|
||||
return fmt.Errorf("detector service already registered")
|
||||
}
|
||||
|
||||
d.callback = callback
|
||||
|
||||
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if len(serviceRegistry) > 0 {
|
||||
serviceRegistry[d] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
serviceRegistry[d] = struct{}{}
|
||||
|
||||
// CFRunLoop must run on a single fixed OS thread
|
||||
go func() {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.registerNotifications()
|
||||
}()
|
||||
|
||||
log.Info("sleep detection service started on macOS")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||
// and the runloop is stopped and cleaned up.
|
||||
func (d *Detector) Deregister() error {
|
||||
serviceRegistryMu.Lock()
|
||||
defer serviceRegistryMu.Unlock()
|
||||
_, exists := serviceRegistry[d]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancel and remove this detector
|
||||
d.cancel()
|
||||
delete(serviceRegistry, d)
|
||||
|
||||
// If other Detectors still exist, leave IOKit running
|
||||
if len(serviceRegistry) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("sleep detection service stopping (deregister)")
|
||||
|
||||
// Deregister IOKit notifications, stop runloop, and free resources
|
||||
C.unregisterNotifications()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Detector) triggerCallback(event EventType) {
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
defer timeout.Stop()
|
||||
|
||||
cb := d.callback
|
||||
go func(callback func(event EventType)) {
|
||||
log.Info("sleep detection event fired")
|
||||
callback(event)
|
||||
close(doneChan)
|
||||
}(cb)
|
||||
|
||||
select {
|
||||
case <-doneChan:
|
||||
case <-d.ctx.Done():
|
||||
case <-timeout.C:
|
||||
log.Warnf("sleep callback timed out")
|
||||
}
|
||||
}
|
||||
9
client/internal/sleep/detector_notsupported.go
Normal file
9
client/internal/sleep/detector_notsupported.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !darwin || ios
|
||||
|
||||
package sleep
|
||||
|
||||
import "fmt"
|
||||
|
||||
func NewDetector() (detector, error) {
|
||||
return nil, fmt.Errorf("sleep not supported on this platform")
|
||||
}
|
||||
37
client/internal/sleep/service.go
Normal file
37
client/internal/sleep/service.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package sleep
|
||||
|
||||
var (
|
||||
EventTypeUnknown EventType = 0
|
||||
EventTypeSleep EventType = 1
|
||||
EventTypeWakeUp EventType = 2
|
||||
)
|
||||
|
||||
type EventType int
|
||||
|
||||
type detector interface {
|
||||
Register(callback func(eventType EventType)) error
|
||||
Deregister() error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
detector detector
|
||||
}
|
||||
|
||||
func New() (*Service, error) {
|
||||
d, err := NewDetector()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Service{
|
||||
detector: d,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Register(callback func(eventType EventType)) error {
|
||||
return s.detector.Register(callback)
|
||||
}
|
||||
|
||||
func (s *Service) Deregister() error {
|
||||
return s.detector.Deregister()
|
||||
}
|
||||
35
client/internal/updatemanager/doc.go
Normal file
35
client/internal/updatemanager/doc.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Package updatemanager provides automatic update management for the NetBird client.
|
||||
// It monitors for new versions, handles update triggers from management server directives,
|
||||
// and orchestrates the download and installation of client updates.
|
||||
//
|
||||
// # Overview
|
||||
//
|
||||
// The update manager operates as a background service that continuously monitors for
|
||||
// available updates and automatically initiates the update process when conditions are met.
|
||||
// It integrates with the installer package to perform the actual installation.
|
||||
//
|
||||
// # Update Flow
|
||||
//
|
||||
// The complete update process follows these steps:
|
||||
//
|
||||
// 1. Manager receives update directive via SetVersion() or detects new version
|
||||
// 2. Manager validates update should proceed (version comparison, rate limiting)
|
||||
// 3. Manager publishes "updating" event to status recorder
|
||||
// 4. Manager persists UpdateState to track update attempt
|
||||
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
|
||||
// 6. Manager triggers installation via installer.RunInstallation()
|
||||
// 7. Installer package handles the actual installation process
|
||||
// 8. On next startup, CheckUpdateSuccess() verifies update completion
|
||||
// 9. Manager publishes success/failure event to status recorder
|
||||
// 10. Manager cleans up UpdateState
|
||||
//
|
||||
// # State Management
|
||||
//
|
||||
// Update state is persisted across restarts to track update attempts:
|
||||
//
|
||||
// - PreUpdateVersion: Version before update attempt
|
||||
// - TargetVersion: Version attempting to update to
|
||||
//
|
||||
// This enables verification of successful updates and appropriate user notification
|
||||
// after the client restarts with the new version.
|
||||
package updatemanager
|
||||
138
client/internal/updatemanager/downloader/downloader.go
Normal file
138
client/internal/updatemanager/downloader/downloader.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
userAgent = "NetBird agent installer/%s"
|
||||
DefaultRetryDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
|
||||
log.Debugf("starting download from %s", url)
|
||||
|
||||
out, err := os.Create(dstFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := out.Close(); cerr != nil {
|
||||
log.Warnf("error closing file %q: %v", dstFile, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
// First attempt
|
||||
err = downloadToFileOnce(ctx, url, out)
|
||||
if err == nil {
|
||||
log.Infof("successfully downloaded file to %s", dstFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If retryDelay is 0, don't retry
|
||||
if retryDelay == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
|
||||
|
||||
// Sleep before retry
|
||||
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
|
||||
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
|
||||
}
|
||||
|
||||
// Truncate file before retry
|
||||
if err := out.Truncate(0); err != nil {
|
||||
return fmt.Errorf("failed to truncate file on retry: %w", err)
|
||||
}
|
||||
if _, err := out.Seek(0, 0); err != nil {
|
||||
return fmt.Errorf("failed to seek to beginning of file: %w", err)
|
||||
}
|
||||
|
||||
// Second attempt
|
||||
if err := downloadToFileOnce(ctx, url, out); err != nil {
|
||||
return fmt.Errorf("download failed after retry: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("successfully downloaded file to %s", dstFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// Add User-Agent header
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := resp.Body.Close(); cerr != nil {
|
||||
log.Warnf("error closing response body: %v", cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// Add User-Agent header
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to perform HTTP request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := resp.Body.Close(); cerr != nil {
|
||||
log.Warnf("error closing response body: %v", cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||
return fmt.Errorf("failed to write response body to file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sleepWithContext(ctx context.Context, duration time.Duration) error {
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
199
client/internal/updatemanager/downloader/downloader_test.go
Normal file
199
client/internal/updatemanager/downloader/downloader_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
retryDelay = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
func TestDownloadToFile_Success(t *testing.T) {
|
||||
// Create a test server that responds successfully
|
||||
content := "test file content"
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(content))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary file for download
|
||||
tempDir := t.TempDir()
|
||||
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||
|
||||
// Download the file
|
||||
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify the file content
|
||||
data, err := os.ReadFile(dstFile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != content {
|
||||
t.Errorf("expected content %q, got %q", content, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
|
||||
content := "test file content after retry"
|
||||
var attemptCount atomic.Int32
|
||||
|
||||
// Create a test server that fails on first attempt, succeeds on second
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempt := attemptCount.Add(1)
|
||||
if attempt == 1 {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("error"))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(content))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary file for download
|
||||
tempDir := t.TempDir()
|
||||
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||
|
||||
// Download the file (should succeed after retry)
|
||||
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
|
||||
t.Fatalf("expected no error after retry, got: %v", err)
|
||||
}
|
||||
|
||||
// Verify the file content
|
||||
data, err := os.ReadFile(dstFile)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != content {
|
||||
t.Errorf("expected content %q, got %q", content, string(data))
|
||||
}
|
||||
|
||||
// Verify it took 2 attempts
|
||||
if attemptCount.Load() != 2 {
|
||||
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
|
||||
var attemptCount atomic.Int32
|
||||
|
||||
// Create a test server that always fails
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount.Add(1)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary file for download
|
||||
tempDir := t.TempDir()
|
||||
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||
|
||||
// Download the file (should fail after retry)
|
||||
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
|
||||
t.Fatal("expected error after retry, got nil")
|
||||
}
|
||||
|
||||
// Verify it tried 2 times
|
||||
if attemptCount.Load() != 2 {
|
||||
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
|
||||
var attemptCount atomic.Int32
|
||||
|
||||
// Create a test server that always fails
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount.Add(1)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary file for download
|
||||
tempDir := t.TempDir()
|
||||
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||
|
||||
// Create a context that will be cancelled during retry delay
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Cancel after a short delay (during the retry sleep)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Download the file (should fail due to context cancellation during retry)
|
||||
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error due to context cancellation, got nil")
|
||||
}
|
||||
|
||||
// Should have only made 1 attempt (cancelled during retry delay)
|
||||
if attemptCount.Load() != 1 {
|
||||
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToFile_InvalidURL(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||
|
||||
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid URL, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToFile_InvalidDestination(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("test"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Use an invalid destination path
|
||||
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid destination, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadToFile_NoRetry(t *testing.T) {
|
||||
var attemptCount atomic.Int32
|
||||
|
||||
// Create a test server that always fails
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attemptCount.Add(1)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte("error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary file for download
|
||||
tempDir := t.TempDir()
|
||||
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||
|
||||
// Download the file with retryDelay = 0 (should not retry)
|
||||
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
// Verify it only made 1 attempt (no retry)
|
||||
if attemptCount.Load() != 1 {
|
||||
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package installer
|
||||
|
||||
func UpdaterBinaryNameWithoutExtension() string {
|
||||
return updaterBinary
|
||||
}
|
||||
11
client/internal/updatemanager/installer/binary_windows.go
Normal file
11
client/internal/updatemanager/installer/binary_windows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func UpdaterBinaryNameWithoutExtension() string {
|
||||
ext := filepath.Ext(updaterBinary)
|
||||
return strings.TrimSuffix(updaterBinary, ext)
|
||||
}
|
||||
111
client/internal/updatemanager/installer/doc.go
Normal file
111
client/internal/updatemanager/installer/doc.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Package installer provides functionality for managing NetBird application
|
||||
// updates and installations across Windows, macOS. It handles
|
||||
// the complete update lifecycle including artifact download, cryptographic verification,
|
||||
// installation execution, process management, and result reporting.
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// The installer package uses a two-process architecture to enable self-updates:
|
||||
//
|
||||
// 1. Service Process: The main NetBird daemon process that initiates updates
|
||||
// 2. Updater Process: A detached child process that performs the actual installation
|
||||
//
|
||||
// This separation is critical because:
|
||||
// - The service binary cannot update itself while running
|
||||
// - The installer (EXE/MSI/PKG) will terminate the service during installation
|
||||
// - The updater process survives service termination and restarts it after installation
|
||||
// - Results can be communicated back to the service after it restarts
|
||||
//
|
||||
// # Update Flow
|
||||
//
|
||||
// Service Process (RunInstallation):
|
||||
//
|
||||
// 1. Validates target version format (semver)
|
||||
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
|
||||
// 3. Downloads installer file from GitHub releases (if applicable)
|
||||
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
|
||||
// launching updater)
|
||||
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
|
||||
// 6. Launches updater process with detached mode:
|
||||
// - --temp-dir: Temporary directory path
|
||||
// - --service-dir: Service installation directory
|
||||
// - --installer-file: Path to downloaded installer (if applicable)
|
||||
// - --dry-run: Optional flag to test without actually installing
|
||||
// 7. Service process continues running (will be terminated by installer later)
|
||||
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
|
||||
//
|
||||
// Updater Process (Setup):
|
||||
//
|
||||
// 1. Receives parameters from service via command-line arguments
|
||||
// 2. Runs installer with appropriate silent/quiet flags:
|
||||
// - Windows EXE: installer.exe /S
|
||||
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
|
||||
// - macOS PKG: installer -pkg installer.pkg -target /
|
||||
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
|
||||
// 3. Installer terminates daemon and UI processes
|
||||
// 4. Installer replaces binaries with new version
|
||||
// 5. Updater waits for installer to complete
|
||||
// 6. Updater restarts daemon:
|
||||
// - Windows: netbird.exe service start
|
||||
// - macOS/Linux: netbird service start
|
||||
// 7. Updater restarts UI:
|
||||
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
|
||||
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
|
||||
// - Linux: Not implemented (UI typically auto-starts)
|
||||
// 8. Updater writes result.json with success/error status
|
||||
// 9. Updater process exits
|
||||
//
|
||||
// # Result Communication
|
||||
//
|
||||
// The ResultHandler (result.go) manages communication between updater and service:
|
||||
//
|
||||
// Result Structure:
|
||||
//
|
||||
// type Result struct {
|
||||
// Success bool // true if installation succeeded
|
||||
// Error string // error message if Success is false
|
||||
// ExecutedAt time.Time // when installation completed
|
||||
// }
|
||||
//
|
||||
// Result files are automatically cleaned up after being read.
|
||||
//
|
||||
// # File Locations
|
||||
//
|
||||
// Temporary Directory (platform-specific):
|
||||
//
|
||||
// Windows:
|
||||
// - Path: %ProgramData%\Netbird\tmp-install
|
||||
// - Example: C:\ProgramData\Netbird\tmp-install
|
||||
//
|
||||
// macOS:
|
||||
// - Path: /var/lib/netbird/tmp-install
|
||||
// - Requires root permissions
|
||||
//
|
||||
// Files created during installation:
|
||||
//
|
||||
// tmp-install/
|
||||
// installer.log
|
||||
// updater[.exe] # Copy of service binary
|
||||
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
|
||||
// result.json # Installation result
|
||||
// msi.log # MSI verbose log (Windows MSI only)
|
||||
//
|
||||
// # API Reference
|
||||
//
|
||||
// # Cleanup
|
||||
//
|
||||
// CleanUpInstallerFiles() removes temporary files after successful installation:
|
||||
// - Downloaded installer files (*.exe, *.msi, *.pkg)
|
||||
// - Updater binary copy
|
||||
// - Does NOT remove result.json (cleaned by ResultHandler after read)
|
||||
// - Does NOT remove msi.log (kept for debugging)
|
||||
//
|
||||
// # Dry-Run Mode
|
||||
//
|
||||
// Dry-run mode allows testing the update process without actually installing:
|
||||
//
|
||||
// Enable via environment variable:
|
||||
//
|
||||
// export NB_AUTO_UPDATE_DRY_RUN=true
|
||||
// netbird service install-update 0.29.0
|
||||
package installer
|
||||
50
client/internal/updatemanager/installer/installer.go
Normal file
50
client/internal/updatemanager/installer/installer.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
updaterBinary = "updater"
|
||||
)
|
||||
|
||||
type Installer struct {
|
||||
tempDir string
|
||||
}
|
||||
|
||||
// New used by the service
|
||||
func New() *Installer {
|
||||
return &Installer{}
|
||||
}
|
||||
|
||||
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
|
||||
func NewWithDir(tempDir string) *Installer {
|
||||
return &Installer{
|
||||
tempDir: tempDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Installer) TempDir() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Installer) LogFiles() []string {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func (u *Installer) CleanUpInstallerFiles() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
|
||||
return fmt.Errorf("unsupported platform")
|
||||
}
|
||||
|
||||
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||
// This will be run by the updater process
|
||||
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
|
||||
return fmt.Errorf("unsupported platform")
|
||||
}
|
||||
293
client/internal/updatemanager/installer/installer_common.go
Normal file
293
client/internal/updatemanager/installer/installer_common.go
Normal file
@@ -0,0 +1,293 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
)
|
||||
|
||||
type Installer struct {
|
||||
tempDir string
|
||||
}
|
||||
|
||||
// New used by the service
|
||||
func New() *Installer {
|
||||
return &Installer{
|
||||
tempDir: defaultTempDir,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
|
||||
func NewWithDir(tempDir string) *Installer {
|
||||
return &Installer{
|
||||
tempDir: tempDir,
|
||||
}
|
||||
}
|
||||
|
||||
// RunInstallation starts the updater process to run the installation
|
||||
// This will run by the original service process
|
||||
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
|
||||
resultHandler := NewResultHandler(u.tempDir)
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
|
||||
log.Errorf("failed to write error result: %v", writeErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := validateTargetVersion(targetVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := u.mkTempDir(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var installerFile string
|
||||
// Download files only when not using any third-party store
|
||||
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
|
||||
log.Infof("download installer")
|
||||
var err error
|
||||
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
|
||||
if err != nil {
|
||||
log.Errorf("failed to download installer: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create artifact verify: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
|
||||
log.Errorf("artifact verification error: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("running installer")
|
||||
updaterPath, err := u.copyUpdater()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// the directory where the service has been installed
|
||||
workspace, err := getServiceDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"--temp-dir", u.tempDir,
|
||||
"--service-dir", workspace,
|
||||
}
|
||||
|
||||
if isDryRunEnabled() {
|
||||
args = append(args, "--dry-run=true")
|
||||
}
|
||||
|
||||
if installerFile != "" {
|
||||
args = append(args, "--installer-file", installerFile)
|
||||
}
|
||||
|
||||
updateCmd := exec.Command(updaterPath, args...)
|
||||
log.Infof("starting updater process: %s", updateCmd.String())
|
||||
|
||||
// Configure the updater to run in a separate session/process group
|
||||
// so it survives the parent daemon being stopped
|
||||
setUpdaterProcAttr(updateCmd)
|
||||
|
||||
// Start the updater process asynchronously
|
||||
if err := updateCmd.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pid := updateCmd.Process.Pid
|
||||
log.Infof("updater started with PID %d", pid)
|
||||
|
||||
// Release the process so the OS can fully detach it
|
||||
if err := updateCmd.Process.Release(); err != nil {
|
||||
log.Warnf("failed to release updater process: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUpInstallerFiles
|
||||
// - the installer file (pkg, exe, msi)
|
||||
// - the selfcopy updater.exe
|
||||
func (u *Installer) CleanUpInstallerFiles() error {
|
||||
// Check if tempDir exists
|
||||
info, err := os.Stat(u.tempDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(u.tempDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := entry.Name()
|
||||
for _, ext := range binaryExtensions {
|
||||
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
|
||||
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return merr.ErrorOrNil()
|
||||
}
|
||||
|
||||
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
|
||||
fileURL := urlWithVersionArch(installerType, targetVersion)
|
||||
|
||||
// Clean up temp directory on error
|
||||
var success bool
|
||||
defer func() {
|
||||
if !success {
|
||||
if err := os.RemoveAll(u.tempDir); err != nil {
|
||||
log.Errorf("error cleaning up temporary directory: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fileName := path.Base(fileURL)
|
||||
if fileName == "." || fileName == "/" || fileName == "" {
|
||||
return "", fmt.Errorf("invalid file URL: %s", fileURL)
|
||||
}
|
||||
|
||||
outputFilePath := filepath.Join(u.tempDir, fileName)
|
||||
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
success = true
|
||||
return outputFilePath, nil
|
||||
}
|
||||
|
||||
func (u *Installer) TempDir() string {
|
||||
return u.tempDir
|
||||
}
|
||||
|
||||
func (u *Installer) mkTempDir() error {
|
||||
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
|
||||
log.Debugf("failed to create tempdir: %s", u.tempDir)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) copyUpdater() (string, error) {
|
||||
src, err := getServiceBinary()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get updater binary: %w", err)
|
||||
}
|
||||
|
||||
dst := filepath.Join(u.tempDir, updaterBinary)
|
||||
if err := copyFile(src, dst); err != nil {
|
||||
return "", fmt.Errorf("failed to copy updater binary: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(dst, 0o755); err != nil {
|
||||
return "", fmt.Errorf("failed to set permissions: %w", err)
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func validateTargetVersion(targetVersion string) error {
|
||||
if targetVersion == "" {
|
||||
return fmt.Errorf("target version cannot be empty")
|
||||
}
|
||||
|
||||
_, err := goversion.NewVersion(targetVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(src, dst string) error {
|
||||
log.Infof("copying %s to %s", src, dst)
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open source: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := in.Close(); err != nil {
|
||||
log.Warnf("failed to close source file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create destination: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Warnf("failed to close destination file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(out, in); err != nil {
|
||||
return fmt.Errorf("copy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getServiceDir() (string, error) {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(exePath), nil
|
||||
}
|
||||
|
||||
func getServiceBinary() (string, error) {
|
||||
return os.Executable()
|
||||
}
|
||||
|
||||
func isDryRunEnabled() bool {
|
||||
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func (u *Installer) LogFiles() []string {
|
||||
return []string{
|
||||
filepath.Join(u.tempDir, LogFile),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func (u *Installer) LogFiles() []string {
|
||||
return []string{
|
||||
filepath.Join(u.tempDir, msiLogFile),
|
||||
filepath.Join(u.tempDir, LogFile),
|
||||
}
|
||||
}
|
||||
238
client/internal/updatemanager/installer/installer_run_darwin.go
Normal file
238
client/internal/updatemanager/installer/installer_run_darwin.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
daemonName = "netbird"
|
||||
updaterBinary = "updater"
|
||||
uiBinary = "/Applications/NetBird.app"
|
||||
|
||||
defaultTempDir = "/var/lib/netbird/tmp-install"
|
||||
|
||||
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
|
||||
)
|
||||
|
||||
var (
|
||||
binaryExtensions = []string{"pkg"}
|
||||
)
|
||||
|
||||
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||
// This will be run by the updater process
|
||||
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
|
||||
resultHandler := NewResultHandler(u.tempDir)
|
||||
|
||||
// Always ensure daemon and UI are restarted after setup
|
||||
defer func() {
|
||||
log.Infof("write out result")
|
||||
var err error
|
||||
if resultErr == nil {
|
||||
err = resultHandler.WriteSuccess()
|
||||
} else {
|
||||
err = resultHandler.WriteErr(resultErr)
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("failed to write update result: %v", err)
|
||||
}
|
||||
|
||||
// skip service restart if dry-run mode is enabled
|
||||
if dryRun {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("starting daemon back")
|
||||
if err := u.startDaemon(daemonFolder); err != nil {
|
||||
log.Errorf("failed to start daemon: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("starting UI back")
|
||||
if err := u.startUIAsUser(); err != nil {
|
||||
log.Errorf("failed to start UI: %v", err)
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
if dryRun {
|
||||
time.Sleep(7 * time.Second)
|
||||
log.Infof("dry-run mode enabled, skipping actual installation")
|
||||
resultErr = fmt.Errorf("dry-run mode enabled")
|
||||
return
|
||||
}
|
||||
|
||||
switch TypeOfInstaller(ctx) {
|
||||
case TypePKG:
|
||||
resultErr = u.installPkgFile(ctx, installerFile)
|
||||
case TypeHomebrew:
|
||||
resultErr = u.updateHomeBrew(ctx)
|
||||
}
|
||||
|
||||
return resultErr
|
||||
}
|
||||
|
||||
func (u *Installer) startDaemon(daemonFolder string) error {
|
||||
log.Infof("starting netbird service")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
|
||||
return err
|
||||
}
|
||||
log.Infof("netbird service started successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) startUIAsUser() error {
|
||||
log.Infof("starting netbird-ui: %s", uiBinary)
|
||||
|
||||
// Get the current console user
|
||||
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get console user: %w", err)
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(string(output))
|
||||
if username == "" || username == "root" {
|
||||
return fmt.Errorf("no active user session found")
|
||||
}
|
||||
|
||||
log.Infof("starting UI for user: %s", username)
|
||||
|
||||
// Get user's UID
|
||||
userInfo, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
// Start the UI process as the console user using launchctl
|
||||
// This ensures the app runs in the user's context with proper GUI access
|
||||
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
|
||||
log.Infof("launchCmd: %s", launchCmd.String())
|
||||
// Set the user's home directory for proper macOS app behavior
|
||||
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
|
||||
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
|
||||
|
||||
if err := launchCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start UI process: %w", err)
|
||||
}
|
||||
|
||||
// Release the process so it can run independently
|
||||
if err := launchCmd.Process.Release(); err != nil {
|
||||
log.Warnf("failed to release UI process: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("netbird-ui started successfully for user %s", username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
|
||||
log.Infof("installing pkg file: %s", path)
|
||||
|
||||
// Kill any existing UI processes before installation
|
||||
// This ensures the postinstall script's "open $APP" will start the new version
|
||||
u.killUI()
|
||||
|
||||
volume := "/"
|
||||
|
||||
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("error running pkg file: %w", err)
|
||||
}
|
||||
log.Infof("installer started with PID %d", cmd.Process.Pid)
|
||||
if err := cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("error running pkg file: %w", err)
|
||||
}
|
||||
log.Infof("pkg file installed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) updateHomeBrew(ctx context.Context) error {
|
||||
log.Infof("updating homebrew")
|
||||
|
||||
// Kill any existing UI processes before upgrade
|
||||
// This ensures the new version will be started after upgrade
|
||||
u.killUI()
|
||||
|
||||
// Homebrew must be run as a non-root user
|
||||
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
|
||||
// Check both Apple Silicon and Intel Mac paths
|
||||
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
|
||||
brewBinPath := "/opt/homebrew/bin/brew"
|
||||
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
|
||||
// Try Intel Mac path
|
||||
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
|
||||
brewBinPath = "/usr/local/bin/brew"
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(brewTapPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting homebrew installation path info: %w", err)
|
||||
}
|
||||
|
||||
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
|
||||
}
|
||||
|
||||
// Get username from UID
|
||||
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error looking up brew installer user: %w", err)
|
||||
}
|
||||
userName := brewUser.Username
|
||||
// Get user HOME, required for brew to run correctly
|
||||
// https://github.com/Homebrew/brew/issues/15833
|
||||
homeDir := brewUser.HomeDir
|
||||
|
||||
// Check if netbird-ui is installed (must run as the brew user, not root)
|
||||
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
|
||||
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
|
||||
uiInstalled := checkUICmd.Run() == nil
|
||||
|
||||
// Homebrew does not support installing specific versions
|
||||
// Thus it will always update to latest and ignore targetVersion
|
||||
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
|
||||
if uiInstalled {
|
||||
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
|
||||
cmd.Env = append(os.Environ(), "HOME="+homeDir)
|
||||
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
|
||||
}
|
||||
|
||||
log.Infof("homebrew updated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) killUI() {
|
||||
log.Infof("killing existing netbird-ui processes")
|
||||
cmd := exec.Command("pkill", "-x", "netbird-ui")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
// pkill returns exit code 1 if no processes matched, which is fine
|
||||
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
|
||||
} else {
|
||||
log.Infof("netbird-ui processes killed")
|
||||
}
|
||||
}
|
||||
|
||||
func urlWithVersionArch(_ Type, version string) string {
|
||||
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
|
||||
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||
}
|
||||
213
client/internal/updatemanager/installer/installer_run_windows.go
Normal file
213
client/internal/updatemanager/installer/installer_run_windows.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
daemonName = "netbird.exe"
|
||||
uiName = "netbird-ui.exe"
|
||||
updaterBinary = "updater.exe"
|
||||
|
||||
msiLogFile = "msi.log"
|
||||
|
||||
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
|
||||
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
|
||||
|
||||
// for the cleanup
|
||||
binaryExtensions = []string{"msi", "exe"}
|
||||
)
|
||||
|
||||
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||
// This will be run by the updater process
|
||||
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
|
||||
resultHandler := NewResultHandler(u.tempDir)
|
||||
|
||||
// Always ensure daemon and UI are restarted after setup
|
||||
defer func() {
|
||||
log.Infof("starting daemon back")
|
||||
if err := u.startDaemon(daemonFolder); err != nil {
|
||||
log.Errorf("failed to start daemon: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("starting UI back")
|
||||
if err := u.startUIAsUser(daemonFolder); err != nil {
|
||||
log.Errorf("failed to start UI: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("write out result")
|
||||
var err error
|
||||
if resultErr == nil {
|
||||
err = resultHandler.WriteSuccess()
|
||||
} else {
|
||||
err = resultHandler.WriteErr(resultErr)
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("failed to write update result: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if dryRun {
|
||||
log.Infof("dry-run mode enabled, skipping actual installation")
|
||||
resultErr = fmt.Errorf("dry-run mode enabled")
|
||||
return
|
||||
}
|
||||
|
||||
installerType, err := typeByFileExtension(installerFile)
|
||||
if err != nil {
|
||||
log.Debugf("%v", err)
|
||||
resultErr = err
|
||||
return
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
switch installerType {
|
||||
case TypeExe:
|
||||
log.Infof("run exe installer: %s", installerFile)
|
||||
cmd = exec.CommandContext(ctx, installerFile, "/S")
|
||||
default:
|
||||
installerDir := filepath.Dir(installerFile)
|
||||
logPath := filepath.Join(installerDir, msiLogFile)
|
||||
log.Infof("run msi installer: %s", installerFile)
|
||||
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
|
||||
}
|
||||
|
||||
cmd.Dir = filepath.Dir(installerFile)
|
||||
|
||||
if resultErr = cmd.Start(); resultErr != nil {
|
||||
log.Errorf("error starting installer: %v", resultErr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("installer started with PID %d", cmd.Process.Pid)
|
||||
if resultErr = cmd.Wait(); resultErr != nil {
|
||||
log.Errorf("installer process finished with error: %v", resultErr)
|
||||
return
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) startDaemon(daemonFolder string) error {
|
||||
log.Infof("starting netbird service")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
|
||||
return err
|
||||
}
|
||||
log.Infof("netbird service started successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Installer) startUIAsUser(daemonFolder string) error {
|
||||
uiPath := filepath.Join(daemonFolder, uiName)
|
||||
log.Infof("starting netbird-ui: %s", uiPath)
|
||||
|
||||
// Get the active console session ID
|
||||
sessionID := windows.WTSGetActiveConsoleSessionId()
|
||||
if sessionID == 0xFFFFFFFF {
|
||||
return fmt.Errorf("no active user session found")
|
||||
}
|
||||
|
||||
// Get the user token for that session
|
||||
var userToken windows.Token
|
||||
err := windows.WTSQueryUserToken(sessionID, &userToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := userToken.Close(); err != nil {
|
||||
log.Warnf("failed to close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Duplicate the token to a primary token
|
||||
var primaryToken windows.Token
|
||||
err = windows.DuplicateTokenEx(
|
||||
userToken,
|
||||
windows.MAXIMUM_ALLOWED,
|
||||
nil,
|
||||
windows.SecurityImpersonation,
|
||||
windows.TokenPrimary,
|
||||
&primaryToken,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to duplicate token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := primaryToken.Close(); err != nil {
|
||||
log.Warnf("failed to close token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Prepare startup info
|
||||
var si windows.StartupInfo
|
||||
si.Cb = uint32(unsafe.Sizeof(si))
|
||||
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
|
||||
|
||||
var pi windows.ProcessInformation
|
||||
|
||||
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert path to UTF16: %w", err)
|
||||
}
|
||||
|
||||
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
|
||||
|
||||
err = windows.CreateProcessAsUser(
|
||||
primaryToken,
|
||||
nil,
|
||||
cmdLine,
|
||||
nil,
|
||||
nil,
|
||||
false,
|
||||
creationFlags,
|
||||
nil,
|
||||
nil,
|
||||
&si,
|
||||
&pi,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
|
||||
}
|
||||
|
||||
// Close handles
|
||||
if err := windows.CloseHandle(pi.Process); err != nil {
|
||||
log.Warnf("failed to close process handle: %v", err)
|
||||
}
|
||||
if err := windows.CloseHandle(pi.Thread); err != nil {
|
||||
log.Warnf("failed to close thread handle: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("netbird-ui started successfully in session %d", sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func urlWithVersionArch(it Type, version string) string {
|
||||
var url string
|
||||
if it == TypeExe {
|
||||
url = exeDownloadURL
|
||||
} else {
|
||||
url = msiDownloadURL
|
||||
}
|
||||
url = strings.ReplaceAll(url, "%version", version)
|
||||
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||
}
|
||||
5
client/internal/updatemanager/installer/log.go
Normal file
5
client/internal/updatemanager/installer/log.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package installer
|
||||
|
||||
const (
|
||||
LogFile = "installer.log"
|
||||
)
|
||||
15
client/internal/updatemanager/installer/procattr_darwin.go
Normal file
15
client/internal/updatemanager/installer/procattr_darwin.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setUpdaterProcAttr configures the updater process to run in a new session,
|
||||
// making it independent of the parent daemon process. This ensures the updater
|
||||
// survives when the daemon is stopped during the pkg installation.
|
||||
func setUpdaterProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setsid: true,
|
||||
}
|
||||
}
|
||||
14
client/internal/updatemanager/installer/procattr_windows.go
Normal file
14
client/internal/updatemanager/installer/procattr_windows.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// setUpdaterProcAttr configures the updater process to run detached from the parent,
|
||||
// making it independent of the parent daemon process.
|
||||
func setUpdaterProcAttr(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
|
||||
}
|
||||
}
|
||||
7
client/internal/updatemanager/installer/repourl_dev.go
Normal file
7
client/internal/updatemanager/installer/repourl_dev.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build devartifactsign
|
||||
|
||||
package installer
|
||||
|
||||
const (
|
||||
DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
|
||||
)
|
||||
7
client/internal/updatemanager/installer/repourl_prod.go
Normal file
7
client/internal/updatemanager/installer/repourl_prod.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !devartifactsign
|
||||
|
||||
package installer
|
||||
|
||||
const (
|
||||
DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
|
||||
)
|
||||
230
client/internal/updatemanager/installer/result.go
Normal file
230
client/internal/updatemanager/installer/result.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
resultFile = "result.json"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Success bool
|
||||
Error string
|
||||
ExecutedAt time.Time
|
||||
}
|
||||
|
||||
// ResultHandler handles reading and writing update results
|
||||
type ResultHandler struct {
|
||||
resultFile string
|
||||
}
|
||||
|
||||
// NewResultHandler creates a new communicator with the given directory path
|
||||
// The result file will be created as "result.json" in the specified directory
|
||||
func NewResultHandler(installerDir string) *ResultHandler {
|
||||
// Create it if it doesn't exist
|
||||
// do not care if already exists
|
||||
_ = os.MkdirAll(installerDir, 0o700)
|
||||
|
||||
rh := &ResultHandler{
|
||||
resultFile: filepath.Join(installerDir, resultFile),
|
||||
}
|
||||
return rh
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) GetErrorResultReason() string {
|
||||
result, err := rh.tryReadResult()
|
||||
if err == nil && !result.Success {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if err := rh.cleanup(); err != nil {
|
||||
log.Warnf("failed to cleanup result file: %v", err)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) WriteSuccess() error {
|
||||
result := Result{
|
||||
Success: true,
|
||||
ExecutedAt: time.Now(),
|
||||
}
|
||||
return rh.write(result)
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) WriteErr(errReason error) error {
|
||||
result := Result{
|
||||
Success: false,
|
||||
Error: errReason.Error(),
|
||||
ExecutedAt: time.Now(),
|
||||
}
|
||||
return rh.write(result)
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
|
||||
log.Infof("start watching result: %s", rh.resultFile)
|
||||
|
||||
// Check if file already exists (updater finished before we started watching)
|
||||
if result, err := rh.tryReadResult(); err == nil {
|
||||
log.Infof("installer result: %v", result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
dir := filepath.Dir(rh.resultFile)
|
||||
|
||||
if err := rh.waitForDirectory(ctx, dir); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
return rh.watchForResultFile(ctx, dir)
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
|
||||
ticker := time.NewTicker(300 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := watcher.Close(); err != nil {
|
||||
log.Warnf("failed to close watcher: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := watcher.Add(dir); err != nil {
|
||||
return Result{}, fmt.Errorf("failed to watch directory: %v", err)
|
||||
}
|
||||
|
||||
// Check again after setting up watcher to avoid race condition
|
||||
// (file could have been created between initial check and watcher setup)
|
||||
if result, err := rh.tryReadResult(); err == nil {
|
||||
log.Infof("installer result: %v", result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return Result{}, ctx.Err()
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return Result{}, errors.New("watcher closed unexpectedly")
|
||||
}
|
||||
|
||||
if result, done := rh.handleWatchEvent(event); done {
|
||||
return result, nil
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return Result{}, errors.New("watcher closed unexpectedly")
|
||||
}
|
||||
return Result{}, fmt.Errorf("watcher error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
|
||||
if event.Name != rh.resultFile {
|
||||
return Result{}, false
|
||||
}
|
||||
|
||||
if event.Has(fsnotify.Create) {
|
||||
result, err := rh.tryReadResult()
|
||||
if err != nil {
|
||||
log.Debugf("error while reading result: %v", err)
|
||||
return result, true
|
||||
}
|
||||
log.Infof("installer result: %v", result)
|
||||
return result, true
|
||||
}
|
||||
|
||||
return Result{}, false
|
||||
}
|
||||
|
||||
// Write writes the update result to a file for the UI to read
|
||||
func (rh *ResultHandler) write(result Result) error {
|
||||
log.Infof("write out installer result to: %s", rh.resultFile)
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(rh.resultFile)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
log.Errorf("failed to create directory %s: %v", dir, err)
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write to a temporary file first, then rename for atomic operation
|
||||
tmpPath := rh.resultFile + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
|
||||
log.Errorf("failed to create temp file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
|
||||
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
|
||||
log.Warnf("Failed to remove temp result file: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rh *ResultHandler) cleanup() error {
|
||||
err := os.Remove(rh.resultFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
log.Debugf("delete installer result file: %s", rh.resultFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryReadResult attempts to read and validate the result file
|
||||
func (rh *ResultHandler) tryReadResult() (Result, error) {
|
||||
data, err := os.ReadFile(rh.resultFile)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
var result Result
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return Result{}, fmt.Errorf("invalid result format: %w", err)
|
||||
}
|
||||
|
||||
if err := rh.cleanup(); err != nil {
|
||||
log.Warnf("failed to cleanup result file: %v", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
14
client/internal/updatemanager/installer/types.go
Normal file
14
client/internal/updatemanager/installer/types.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package installer
|
||||
|
||||
type Type struct {
|
||||
name string
|
||||
downloadable bool
|
||||
}
|
||||
|
||||
func (t Type) String() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t Type) Downloadable() bool {
|
||||
return t.downloadable
|
||||
}
|
||||
22
client/internal/updatemanager/installer/types_darwin.go
Normal file
22
client/internal/updatemanager/installer/types_darwin.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
var (
|
||||
TypeHomebrew = Type{name: "Homebrew", downloadable: false}
|
||||
TypePKG = Type{name: "pkg", downloadable: true}
|
||||
)
|
||||
|
||||
func TypeOfInstaller(ctx context.Context) Type {
|
||||
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
|
||||
_, err := cmd.Output()
|
||||
if err != nil && cmd.ProcessState.ExitCode() == 1 {
|
||||
// Not installed using pkg file, thus installed using Homebrew
|
||||
|
||||
return TypeHomebrew
|
||||
}
|
||||
return TypePKG
|
||||
}
|
||||
51
client/internal/updatemanager/installer/types_windows.go
Normal file
51
client/internal/updatemanager/installer/types_windows.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package installer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||
)
|
||||
|
||||
var (
|
||||
TypeExe = Type{name: "EXE", downloadable: true}
|
||||
TypeMSI = Type{name: "MSI", downloadable: true}
|
||||
)
|
||||
|
||||
func TypeOfInstaller(_ context.Context) Type {
|
||||
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
|
||||
|
||||
for _, path := range paths {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := k.Close(); err != nil {
|
||||
log.Warnf("Error closing registry key: %v", err)
|
||||
}
|
||||
return TypeExe
|
||||
|
||||
}
|
||||
|
||||
log.Debug("No registry entry found for Netbird, assuming MSI installation")
|
||||
return TypeMSI
|
||||
}
|
||||
|
||||
func typeByFileExtension(filePath string) (Type, error) {
|
||||
switch {
|
||||
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
|
||||
return TypeExe, nil
|
||||
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
|
||||
return TypeMSI, nil
|
||||
default:
|
||||
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
|
||||
}
|
||||
}
|
||||
374
client/internal/updatemanager/manager.go
Normal file
374
client/internal/updatemanager/manager.go
Normal file
@@ -0,0 +1,374 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
latestVersion = "latest"
|
||||
// this version will be ignored
|
||||
developmentVersion = "development"
|
||||
)
|
||||
|
||||
var errNoUpdateState = errors.New("no update state found")
|
||||
|
||||
type UpdateState struct {
|
||||
PreUpdateVersion string
|
||||
TargetVersion string
|
||||
}
|
||||
|
||||
func (u UpdateState) Name() string {
|
||||
return "autoUpdate"
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
lastTrigger time.Time
|
||||
mgmUpdateChan chan struct{}
|
||||
updateChannel chan struct{}
|
||||
currentVersion string
|
||||
update UpdateInterface
|
||||
wg sync.WaitGroup
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
expectedVersion *v.Version
|
||||
updateToLatestVersion bool
|
||||
|
||||
// updateMutex protect update and expectedVersion fields
|
||||
updateMutex sync.Mutex
|
||||
|
||||
triggerUpdateFn func(context.Context, string) error
|
||||
}
|
||||
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
|
||||
if isBrew {
|
||||
log.Warnf("auto-update disabled on Home Brew installation")
|
||||
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
|
||||
}
|
||||
}
|
||||
return newManager(statusRecorder, stateManager)
|
||||
}
|
||||
|
||||
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
manager := &Manager{
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
}
|
||||
manager.triggerUpdateFn = manager.triggerUpdate
|
||||
|
||||
stateManager.RegisterState(&UpdateState{})
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess checks if the update was successful and send a notification.
|
||||
// It works without to start the update manager.
|
||||
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||
reason := m.lastResultErrReason()
|
||||
if reason != "" {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update failed",
|
||||
fmt.Sprintf("Auto-update failed: %s", reason),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
updateState, err := m.loadAndDeleteUpdateState(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, errNoUpdateState) {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to load update state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("auto-update state loaded, %v", *updateState)
|
||||
|
||||
if updateState.TargetVersion == m.currentVersion {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update completed",
|
||||
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
|
||||
nil,
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
if m.cancel != nil {
|
||||
log.Errorf("Manager already started")
|
||||
return
|
||||
}
|
||||
|
||||
m.update.SetDaemonVersion(m.currentVersion)
|
||||
m.update.SetOnUpdateListener(func() {
|
||||
select {
|
||||
case m.updateChannel <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
go m.update.StartFetcher()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
m.cancel = cancel
|
||||
|
||||
m.wg.Add(1)
|
||||
go m.updateLoop(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) SetVersion(expectedVersion string) {
|
||||
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
|
||||
if m.cancel == nil {
|
||||
log.Errorf("manager not started")
|
||||
return
|
||||
}
|
||||
|
||||
m.updateMutex.Lock()
|
||||
defer m.updateMutex.Unlock()
|
||||
|
||||
if expectedVersion == "" {
|
||||
log.Errorf("empty expected version provided")
|
||||
m.expectedVersion = nil
|
||||
m.updateToLatestVersion = false
|
||||
return
|
||||
}
|
||||
|
||||
if expectedVersion == latestVersion {
|
||||
m.updateToLatestVersion = true
|
||||
m.expectedVersion = nil
|
||||
} else {
|
||||
expectedSemVer, err := v.NewVersion(expectedVersion)
|
||||
if err != nil {
|
||||
log.Errorf("error parsing version: %v", err)
|
||||
return
|
||||
}
|
||||
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
|
||||
return
|
||||
}
|
||||
m.expectedVersion = expectedSemVer
|
||||
m.updateToLatestVersion = false
|
||||
}
|
||||
|
||||
select {
|
||||
case m.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) Stop() {
|
||||
if m.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.cancel()
|
||||
m.updateMutex.Lock()
|
||||
if m.update != nil {
|
||||
m.update.StopWatch()
|
||||
m.update = nil
|
||||
}
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
m.wg.Wait()
|
||||
}
|
||||
|
||||
func (m *Manager) onContextCancel() {
|
||||
if m.cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.updateMutex.Lock()
|
||||
defer m.updateMutex.Unlock()
|
||||
if m.update != nil {
|
||||
m.update.StopWatch()
|
||||
m.update = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) updateLoop(ctx context.Context) {
|
||||
defer m.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.onContextCancel()
|
||||
return
|
||||
case <-m.mgmUpdateChan:
|
||||
case <-m.updateChannel:
|
||||
log.Infof("fetched new version info")
|
||||
}
|
||||
|
||||
m.handleUpdate(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
var updateVersion *v.Version
|
||||
|
||||
m.updateMutex.Lock()
|
||||
if m.update == nil {
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
expectedVersion := m.expectedVersion
|
||||
useLatest := m.updateToLatestVersion
|
||||
curLatestVersion := m.update.LatestVersion()
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
switch {
|
||||
// Resolve "latest" to actual version
|
||||
case useLatest:
|
||||
if curLatestVersion == nil {
|
||||
log.Tracef("latest version not fetched yet")
|
||||
return
|
||||
}
|
||||
updateVersion = curLatestVersion
|
||||
// Update to specific version
|
||||
case expectedVersion != nil:
|
||||
updateVersion = expectedVersion
|
||||
default:
|
||||
log.Debugf("no expected version information set")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||
if !m.shouldUpdate(updateVersion) {
|
||||
return
|
||||
}
|
||||
|
||||
m.lastTrigger = time.Now()
|
||||
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Automatically updating client",
|
||||
"Your client version is older than auto-update version set in Management, updating client now.",
|
||||
nil,
|
||||
)
|
||||
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "show", "version": updateVersion.String()},
|
||||
)
|
||||
|
||||
updateState := UpdateState{
|
||||
PreUpdateVersion: m.currentVersion,
|
||||
TargetVersion: updateVersion.String(),
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(updateState); err != nil {
|
||||
log.Warnf("failed to update state: %v", err)
|
||||
} else {
|
||||
if err = m.stateManager.PersistState(ctx); err != nil {
|
||||
log.Warnf("failed to persist state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
|
||||
log.Errorf("Error triggering auto-update: %v", err)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Auto-update failed",
|
||||
fmt.Sprintf("Auto-update failed: %v", err),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
|
||||
// Returns nil if no state exists.
|
||||
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
|
||||
stateType := &UpdateState{}
|
||||
|
||||
m.stateManager.RegisterState(stateType)
|
||||
if err := m.stateManager.LoadState(stateType); err != nil {
|
||||
return nil, fmt.Errorf("load state: %w", err)
|
||||
}
|
||||
|
||||
state := m.stateManager.GetState(stateType)
|
||||
if state == nil {
|
||||
return nil, errNoUpdateState
|
||||
}
|
||||
|
||||
updateState, ok := state.(*UpdateState)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to cast state to UpdateState")
|
||||
}
|
||||
|
||||
if err := m.stateManager.DeleteState(updateState); err != nil {
|
||||
return nil, fmt.Errorf("delete state: %w", err)
|
||||
}
|
||||
|
||||
if err := m.stateManager.PersistState(ctx); err != nil {
|
||||
return nil, fmt.Errorf("persist state: %w", err)
|
||||
}
|
||||
|
||||
return updateState, nil
|
||||
}
|
||||
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
if m.currentVersion == developmentVersion {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
}
|
||||
currentVersion, err := v.NewVersion(m.currentVersion)
|
||||
if err != nil {
|
||||
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
|
||||
return false
|
||||
}
|
||||
if currentVersion.GreaterThanOrEqual(updateVersion) {
|
||||
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Since(m.lastTrigger) < 5*time.Minute {
|
||||
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) lastResultErrReason() string {
|
||||
inst := installer.New()
|
||||
result := installer.NewResultHandler(inst.TempDir())
|
||||
return result.GetErrorResultReason()
|
||||
}
|
||||
|
||||
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
inst := installer.New()
|
||||
return inst.RunInstallation(ctx, targetVersion)
|
||||
}
|
||||
214
client/internal/updatemanager/manager_test.go
Normal file
214
client/internal/updatemanager/manager_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
v.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) LatestVersion() *v.Version {
|
||||
return v.latestVersion
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StartFetcher() {}
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: false,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest")
|
||||
var triggeredInit bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredInit = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredInit = false
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
var triggeredLater bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredLater = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredLater = false
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Update to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion)
|
||||
|
||||
var updateTriggered bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
|
||||
}
|
||||
updateTriggered = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
updateTriggered = false
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
39
client/internal/updatemanager/manager_unsupported.go
Normal file
39
client/internal/updatemanager/manager_unsupported.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Manager is a no-op stub for unsupported platforms
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a no-op manager for unsupported platforms
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
return nil, fmt.Errorf("update manager is not supported on this platform")
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess is a no-op on unsupported platforms
|
||||
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Start is a no-op on unsupported platforms
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// SetVersion is a no-op on unsupported platforms
|
||||
func (m *Manager) SetVersion(expectedVersion string) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Stop is a no-op on unsupported platforms
|
||||
func (m *Manager) Stop() {
|
||||
// no-op
|
||||
}
|
||||
302
client/internal/updatemanager/reposign/artifact.go
Normal file
302
client/internal/updatemanager/reposign/artifact.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package reposign
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/blake2s"
|
||||
)
|
||||
|
||||
const (
|
||||
tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
|
||||
tagArtifactPublic = "ARTIFACT PUBLIC KEY"
|
||||
|
||||
maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
|
||||
maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// ArtifactHash wraps a hash.Hash and counts bytes written
|
||||
type ArtifactHash struct {
|
||||
hash.Hash
|
||||
}
|
||||
|
||||
// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
|
||||
func NewArtifactHash() *ArtifactHash {
|
||||
h, err := blake2s.New256(nil)
|
||||
if err != nil {
|
||||
panic(err) // Should never happen with nil Key
|
||||
}
|
||||
return &ArtifactHash{Hash: h}
|
||||
}
|
||||
|
||||
func (ah *ArtifactHash) Write(b []byte) (int, error) {
|
||||
return ah.Hash.Write(b)
|
||||
}
|
||||
|
||||
// ArtifactKey is a signing Key used to sign artifacts
|
||||
type ArtifactKey struct {
|
||||
PrivateKey
|
||||
}
|
||||
|
||||
func (k ArtifactKey) String() string {
|
||||
return fmt.Sprintf(
|
||||
"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
|
||||
k.Metadata.ID,
|
||||
k.Metadata.CreatedAt.Format(time.RFC3339),
|
||||
k.Metadata.ExpiresAt.Format(time.RFC3339),
|
||||
)
|
||||
}
|
||||
|
||||
func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
|
||||
// Verify root key is still valid
|
||||
if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
|
||||
return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expirationTime := now.Add(expiration)
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||
}
|
||||
|
||||
metadata := KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: now.UTC(),
|
||||
ExpiresAt: expirationTime.UTC(),
|
||||
}
|
||||
|
||||
ak := &ArtifactKey{
|
||||
PrivateKey{
|
||||
Key: priv,
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal PrivateKey struct to JSON
|
||||
privJSON, err := json.Marshal(ak.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
|
||||
}
|
||||
|
||||
// Marshal PublicKey struct to JSON
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: metadata,
|
||||
}
|
||||
pubJSON, err := json.Marshal(pubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
// Encode to PEM with metadata embedded in bytes
|
||||
privPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagArtifactPrivate,
|
||||
Bytes: privJSON,
|
||||
})
|
||||
|
||||
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagArtifactPublic,
|
||||
Bytes: pubJSON,
|
||||
})
|
||||
|
||||
// Sign the public key with the root key
|
||||
signature, err := SignArtifactKey(*rootKey, pubPEM)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
|
||||
}
|
||||
|
||||
return ak, privPEM, pubPEM, signature, nil
|
||||
}
|
||||
|
||||
func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
|
||||
pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
|
||||
if err != nil {
|
||||
return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
|
||||
}
|
||||
return ArtifactKey{pk}, nil
|
||||
}
|
||||
|
||||
func ParseArtifactPubKey(data []byte) (PublicKey, error) {
|
||||
pk, _, err := parsePublicKey(data, tagArtifactPublic)
|
||||
return pk, err
|
||||
}
|
||||
|
||||
func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return nil, nil, errors.New("no keys to bundle")
|
||||
}
|
||||
|
||||
// Create bundle by concatenating PEM-encoded keys
|
||||
var pubBundle []byte
|
||||
|
||||
for _, pk := range keys {
|
||||
// Marshal PublicKey struct to JSON
|
||||
pubJSON, err := json.Marshal(pk)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
// Encode to PEM
|
||||
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagArtifactPublic,
|
||||
Bytes: pubJSON,
|
||||
})
|
||||
|
||||
pubBundle = append(pubBundle, pubPEM...)
|
||||
}
|
||||
|
||||
// Sign the entire bundle with the root key
|
||||
signature, err := SignArtifactKey(*rootKey, pubBundle)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
|
||||
}
|
||||
|
||||
return pubBundle, signature, nil
|
||||
}
|
||||
|
||||
func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
|
||||
now := time.Now().UTC()
|
||||
if signature.Timestamp.After(now.Add(maxClockSkew)) {
|
||||
err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
|
||||
log.Debugf("artifact signature error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
|
||||
err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
|
||||
log.Debugf("artifact signature error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Reconstruct the signed message: artifact_key_data || timestamp
|
||||
msg := make([]byte, 0, len(data)+8)
|
||||
msg = append(msg, data...)
|
||||
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
|
||||
|
||||
if !verifyAny(publicRootKeys, msg, signature.Signature) {
|
||||
return nil, errors.New("failed to verify signature of artifact keys")
|
||||
}
|
||||
|
||||
pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse public keys: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validKeys := make([]PublicKey, 0, len(pubKeys))
|
||||
for _, pubKey := range pubKeys {
|
||||
// Filter out expired keys
|
||||
if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
|
||||
log.Debugf("Key %s is expired at %v (current time %v)",
|
||||
pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
|
||||
continue
|
||||
}
|
||||
|
||||
if revocationList != nil {
|
||||
if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
|
||||
log.Debugf("Key %s is revoked as of %v (created %v)",
|
||||
pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
|
||||
continue
|
||||
}
|
||||
}
|
||||
validKeys = append(validKeys, pubKey)
|
||||
}
|
||||
|
||||
if len(validKeys) == 0 {
|
||||
log.Debugf("no valid public keys found for artifact keys")
|
||||
return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys))
|
||||
}
|
||||
|
||||
return validKeys, nil
|
||||
}
|
||||
|
||||
func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
|
||||
// Validate signature timestamp
|
||||
now := time.Now().UTC()
|
||||
if signature.Timestamp.After(now.Add(maxClockSkew)) {
|
||||
err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
|
||||
log.Debugf("failed to verify signature of artifact: %s", err)
|
||||
return err
|
||||
}
|
||||
if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
|
||||
return fmt.Errorf("artifact signature is too old: %v (created %v)",
|
||||
now.Sub(signature.Timestamp), signature.Timestamp)
|
||||
}
|
||||
|
||||
h := NewArtifactHash()
|
||||
if _, err := h.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to hash artifact: %w", err)
|
||||
}
|
||||
hash := h.Sum(nil)
|
||||
|
||||
// Reconstruct the signed message: hash || length || timestamp
|
||||
msg := make([]byte, 0, len(hash)+8+8)
|
||||
msg = append(msg, hash...)
|
||||
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
|
||||
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
|
||||
|
||||
// Find matching Key and verify
|
||||
for _, keyInfo := range artifactPubKeys {
|
||||
if keyInfo.Metadata.ID == signature.KeyID {
|
||||
// Check Key expiration
|
||||
if !keyInfo.Metadata.ExpiresAt.IsZero() &&
|
||||
signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
|
||||
return fmt.Errorf("signing Key %s expired at %v, signature from %v",
|
||||
signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
|
||||
}
|
||||
|
||||
if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
|
||||
log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
|
||||
}
|
||||
|
||||
func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
|
||||
if len(data) == 0 { // Check happens too late
|
||||
return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
|
||||
}
|
||||
|
||||
h := NewArtifactHash()
|
||||
if _, err := h.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("failed to write artifact hash: %w", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
|
||||
return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
|
||||
}
|
||||
|
||||
hash := h.Sum(nil)
|
||||
|
||||
// Create message: hash || length || timestamp
|
||||
msg := make([]byte, 0, len(hash)+8+8)
|
||||
msg = append(msg, hash...)
|
||||
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
|
||||
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
|
||||
|
||||
sig := ed25519.Sign(artifactKey.Key, msg)
|
||||
|
||||
bundle := Signature{
|
||||
Signature: sig,
|
||||
Timestamp: timestamp,
|
||||
KeyID: artifactKey.Metadata.ID,
|
||||
Algorithm: "ed25519",
|
||||
HashAlgo: "blake2s",
|
||||
}
|
||||
|
||||
return json.Marshal(bundle)
|
||||
}
|
||||
1080
client/internal/updatemanager/reposign/artifact_test.go
Normal file
1080
client/internal/updatemanager/reposign/artifact_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
||||
-----BEGIN ROOT PUBLIC KEY-----
|
||||
eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
|
||||
V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
|
||||
ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
|
||||
X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
|
||||
-----END ROOT PUBLIC KEY-----
|
||||
@@ -0,0 +1,6 @@
|
||||
-----BEGIN ROOT PUBLIC KEY-----
|
||||
eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
|
||||
ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
|
||||
ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
|
||||
X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
|
||||
-----END ROOT PUBLIC KEY-----
|
||||
174
client/internal/updatemanager/reposign/doc.go
Normal file
174
client/internal/updatemanager/reposign/doc.go
Normal file
@@ -0,0 +1,174 @@
|
||||
// Package reposign implements a cryptographic signing and verification system
|
||||
// for NetBird software update artifacts. It provides a hierarchical key
|
||||
// management system with support for key rotation, revocation, and secure
|
||||
// artifact distribution.
|
||||
//
|
||||
// # Architecture
|
||||
//
|
||||
// The package uses a two-tier key hierarchy:
|
||||
//
|
||||
// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
|
||||
// in the client binary and establish the root of trust. Root keys should
|
||||
// be kept offline and highly secured.
|
||||
//
|
||||
// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
|
||||
// packages, etc.). These are rotated regularly and can be revoked if
|
||||
// compromised. Artifact keys are signed by root keys and distributed via
|
||||
// a public repository.
|
||||
//
|
||||
// This separation allows for operational flexibility: artifact keys can be
|
||||
// rotated frequently without requiring client updates, while root keys remain
|
||||
// stable and embedded in the software.
|
||||
//
|
||||
// # Cryptographic Primitives
|
||||
//
|
||||
// The package uses strong, modern cryptographic algorithms:
|
||||
// - Ed25519: Fast, secure digital signatures (no timing attacks)
|
||||
// - BLAKE2s-256: Fast cryptographic hash for artifacts
|
||||
// - SHA-256: Key ID generation
|
||||
// - JSON: Structured key and signature serialization
|
||||
// - PEM: Standard key encoding format
|
||||
//
|
||||
// # Security Features
|
||||
//
|
||||
// Timestamp Binding:
|
||||
// - All signatures include cryptographically-bound timestamps
|
||||
// - Prevents replay attacks and enforces signature freshness
|
||||
// - Clock skew tolerance: 5 minutes
|
||||
//
|
||||
// Key Expiration:
|
||||
// - All keys have expiration times
|
||||
// - Expired keys are automatically rejected
|
||||
// - Signing with an expired key fails immediately
|
||||
//
|
||||
// Key Revocation:
|
||||
// - Compromised keys can be revoked via a signed revocation list
|
||||
// - Revocation list is checked during artifact validation
|
||||
// - Revoked keys are filtered out before artifact verification
|
||||
//
|
||||
// # File Structure
|
||||
//
|
||||
// The package expects the following file layout in the key repository:
|
||||
//
|
||||
// signrepo/
|
||||
// artifact-key-pub.pem # Bundle of artifact public keys
|
||||
// artifact-key-pub.pem.sig # Root signature of the bundle
|
||||
// revocation-list.json # List of revoked key IDs
|
||||
// revocation-list.json.sig # Root signature of revocation list
|
||||
//
|
||||
// And in the artifacts repository:
|
||||
//
|
||||
// releases/
|
||||
// v0.28.0/
|
||||
// netbird-linux-amd64
|
||||
// netbird-linux-amd64.sig # Artifact signature
|
||||
// netbird-darwin-amd64
|
||||
// netbird-darwin-amd64.sig
|
||||
// ...
|
||||
//
|
||||
// # Embedded Root Keys
|
||||
//
|
||||
// Root public keys are embedded in the client binary at compile time:
|
||||
// - Production keys: certs/ directory
|
||||
// - Development keys: certsdev/ directory
|
||||
//
|
||||
// The build tag determines which keys are embedded:
|
||||
// - Production builds: //go:build !devartifactsign
|
||||
// - Development builds: //go:build devartifactsign
|
||||
//
|
||||
// This ensures that development artifacts cannot be verified using production
|
||||
// keys and vice versa.
|
||||
//
|
||||
// # Key Rotation Strategies
|
||||
//
|
||||
// Root Key Rotation:
|
||||
//
|
||||
// Root keys can be rotated without breaking existing clients by leveraging
|
||||
// the multi-key verification system. The loadEmbeddedPublicKeys function
|
||||
// reads ALL files from the certs/ directory and accepts signatures from ANY
|
||||
// of the embedded root keys.
|
||||
//
|
||||
// To rotate root keys:
|
||||
//
|
||||
// 1. Generate a new root key pair:
|
||||
// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
|
||||
//
|
||||
// 2. Add the new public key to the certs/ directory as a new file:
|
||||
// certs/
|
||||
// root-pub-2024.pem # Old key (keep this!)
|
||||
// root-pub-2025.pem # New key (add this)
|
||||
//
|
||||
// 3. Build new client versions with both keys embedded. The verification
|
||||
// will accept signatures from either key.
|
||||
//
|
||||
// 4. Start signing new artifact keys with the new root key. Old clients
|
||||
// with only the old root key will reject these, but new clients with
|
||||
// both keys will accept them.
|
||||
//
|
||||
// Each file in certs/ can contain a single key or a bundle of keys (multiple
|
||||
// PEM blocks). The system will parse all keys from all files and use them
|
||||
// for verification. This provides maximum flexibility for key management.
|
||||
//
|
||||
// Important: Never remove all old root keys at once. Always maintain at least
|
||||
// one overlapping key between releases to ensure smooth transitions.
|
||||
//
|
||||
// Artifact Key Rotation:
|
||||
//
|
||||
// Artifact keys should be rotated regularly (e.g., every 90 days) using the
|
||||
// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
|
||||
// keys to be bundled together in a single signed package, and ValidateArtifact
|
||||
// will accept signatures from ANY key in the bundle.
|
||||
//
|
||||
// To rotate artifact keys smoothly:
|
||||
//
|
||||
// 1. Generate a new artifact key while keeping the old one:
|
||||
// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
|
||||
// // Keep oldPubPEM and oldKey available
|
||||
//
|
||||
// 2. Create a bundle containing both old and new public keys
|
||||
//
|
||||
// 3. Upload the bundle and its signature to the key repository:
|
||||
// signrepo/artifact-key-pub.pem # Contains both keys
|
||||
// signrepo/artifact-key-pub.pem.sig # Root signature
|
||||
//
|
||||
// 4. Start signing new releases with the NEW key, but keep the bundle
|
||||
// unchanged. Clients will download the bundle (containing both keys)
|
||||
// and accept signatures from either key.
|
||||
//
|
||||
// Key bundle validation workflow:
|
||||
// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
|
||||
// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
|
||||
// 3. ValidateArtifactKeys parses all public keys from the bundle
|
||||
// 4. ValidateArtifactKeys filters out expired or revoked keys
|
||||
// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
|
||||
//
|
||||
// This multi-key acceptance model enables overlapping validity periods and
|
||||
// smooth transitions without client update requirements.
|
||||
//
|
||||
// # Best Practices
|
||||
//
|
||||
// Root Key Management:
|
||||
// - Generate root keys offline on an air-gapped machine
|
||||
// - Store root private keys in hardware security modules (HSM) if possible
|
||||
// - Use separate root keys for production and development
|
||||
// - Rotate root keys infrequently (e.g., every 5-10 years)
|
||||
// - Plan for root key rotation: embed multiple root public keys
|
||||
//
|
||||
// Artifact Key Management:
|
||||
// - Rotate artifact keys regularly (e.g., every 90 days)
|
||||
// - Use separate artifact keys for different release channels if needed
|
||||
// - Revoke keys immediately upon suspected compromise
|
||||
// - Bundle multiple artifact keys to enable smooth rotation
|
||||
//
|
||||
// Signing Process:
|
||||
// - Sign artifacts in a secure CI/CD environment
|
||||
// - Never commit private keys to version control
|
||||
// - Use environment variables or secret management for keys
|
||||
// - Verify signatures immediately after signing
|
||||
//
|
||||
// Distribution:
|
||||
// - Serve keys and revocation lists from a reliable CDN
|
||||
// - Use HTTPS for all key and artifact downloads
|
||||
// - Monitor download failures and signature verification failures
|
||||
// - Keep revocation list up to date
|
||||
package reposign
|
||||
10
client/internal/updatemanager/reposign/embed_dev.go
Normal file
10
client/internal/updatemanager/reposign/embed_dev.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build devartifactsign
|
||||
|
||||
package reposign
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed certsdev
|
||||
var embeddedCerts embed.FS
|
||||
|
||||
const embeddedCertsDir = "certsdev"
|
||||
10
client/internal/updatemanager/reposign/embed_prod.go
Normal file
10
client/internal/updatemanager/reposign/embed_prod.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !devartifactsign
|
||||
|
||||
package reposign
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed certs
|
||||
var embeddedCerts embed.FS
|
||||
|
||||
const embeddedCertsDir = "certs"
|
||||
171
client/internal/updatemanager/reposign/key.go
Normal file
171
client/internal/updatemanager/reposign/key.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package reposign
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
maxClockSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key)
|
||||
type KeyID [8]byte
|
||||
|
||||
// computeKeyID generates a unique ID from a public Key
|
||||
func computeKeyID(pub ed25519.PublicKey) KeyID {
|
||||
h := sha256.Sum256(pub)
|
||||
var id KeyID
|
||||
copy(id[:], h[:8])
|
||||
return id
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for KeyID
|
||||
func (k KeyID) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(k.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler for KeyID
|
||||
func (k *KeyID) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
parsed, err := ParseKeyID(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*k = parsed
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
|
||||
func ParseKeyID(s string) (KeyID, error) {
|
||||
var id KeyID
|
||||
if len(s) != 16 {
|
||||
return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
|
||||
}
|
||||
|
||||
b, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return id, fmt.Errorf("failed to decode KeyID: %w", err)
|
||||
}
|
||||
|
||||
copy(id[:], b)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (k KeyID) String() string {
|
||||
return fmt.Sprintf("%x", k[:])
|
||||
}
|
||||
|
||||
// KeyMetadata contains versioning and lifecycle information for a Key
|
||||
type KeyMetadata struct {
|
||||
ID KeyID `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
|
||||
}
|
||||
|
||||
// PublicKey wraps a public Key with its Metadata
|
||||
type PublicKey struct {
|
||||
Key ed25519.PublicKey
|
||||
Metadata KeyMetadata
|
||||
}
|
||||
|
||||
func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
|
||||
var keys []PublicKey
|
||||
for len(bundle) > 0 {
|
||||
keyInfo, rest, err := parsePublicKey(bundle, typeTag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, keyInfo)
|
||||
bundle = rest
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.New("no keys found in bundle")
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
|
||||
b, rest := pem.Decode(data)
|
||||
if b == nil {
|
||||
return PublicKey{}, nil, errors.New("failed to decode PEM data")
|
||||
}
|
||||
if b.Type != typeTag {
|
||||
return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
|
||||
}
|
||||
|
||||
// Unmarshal JSON-embedded format
|
||||
var pub PublicKey
|
||||
if err := json.Unmarshal(b.Bytes, &pub); err != nil {
|
||||
return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
|
||||
}
|
||||
|
||||
// Validate key length
|
||||
if len(pub.Key) != ed25519.PublicKeySize {
|
||||
return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
|
||||
ed25519.PublicKeySize, len(pub.Key))
|
||||
}
|
||||
|
||||
// Always recompute ID to ensure integrity
|
||||
pub.Metadata.ID = computeKeyID(pub.Key)
|
||||
|
||||
return pub, rest, nil
|
||||
}
|
||||
|
||||
type PrivateKey struct {
|
||||
Key ed25519.PrivateKey
|
||||
Metadata KeyMetadata
|
||||
}
|
||||
|
||||
func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
|
||||
b, rest := pem.Decode(data)
|
||||
if b == nil {
|
||||
return PrivateKey{}, errors.New("failed to decode PEM data")
|
||||
}
|
||||
if len(rest) > 0 {
|
||||
return PrivateKey{}, errors.New("trailing PEM data")
|
||||
}
|
||||
if b.Type != typeTag {
|
||||
return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
|
||||
}
|
||||
|
||||
// Unmarshal JSON-embedded format
|
||||
var pk PrivateKey
|
||||
if err := json.Unmarshal(b.Bytes, &pk); err != nil {
|
||||
return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
|
||||
}
|
||||
|
||||
// Validate key length
|
||||
if len(pk.Key) != ed25519.PrivateKeySize {
|
||||
return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
|
||||
ed25519.PrivateKeySize, len(pk.Key))
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
|
||||
// Verify with root keys
|
||||
var rootKeys []ed25519.PublicKey
|
||||
for _, r := range publicRootKeys {
|
||||
rootKeys = append(rootKeys, r.Key)
|
||||
}
|
||||
|
||||
for _, k := range rootKeys {
|
||||
if ed25519.Verify(k, msg, sig) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
636
client/internal/updatemanager/reposign/key_test.go
Normal file
636
client/internal/updatemanager/reposign/key_test.go
Normal file
@@ -0,0 +1,636 @@
|
||||
package reposign
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test KeyID functions
|
||||
|
||||
func TestComputeKeyID(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyID := computeKeyID(pub)
|
||||
|
||||
// Verify it's the first 8 bytes of SHA-256
|
||||
h := sha256.Sum256(pub)
|
||||
expectedID := KeyID{}
|
||||
copy(expectedID[:], h[:8])
|
||||
|
||||
assert.Equal(t, expectedID, keyID)
|
||||
}
|
||||
|
||||
func TestComputeKeyID_Deterministic(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Computing KeyID multiple times should give the same result
|
||||
keyID1 := computeKeyID(pub)
|
||||
keyID2 := computeKeyID(pub)
|
||||
|
||||
assert.Equal(t, keyID1, keyID2)
|
||||
}
|
||||
|
||||
func TestComputeKeyID_DifferentKeys(t *testing.T) {
|
||||
pub1, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyID1 := computeKeyID(pub1)
|
||||
keyID2 := computeKeyID(pub2)
|
||||
|
||||
// Different keys should produce different IDs
|
||||
assert.NotEqual(t, keyID1, keyID2)
|
||||
}
|
||||
|
||||
func TestParseKeyID_Valid(t *testing.T) {
|
||||
hexStr := "0123456789abcdef"
|
||||
|
||||
keyID, err := ParseKeyID(hexStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
|
||||
assert.Equal(t, expected, keyID)
|
||||
}
|
||||
|
||||
func TestParseKeyID_InvalidLength(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"too short", "01234567"},
|
||||
{"too long", "0123456789abcdef00"},
|
||||
{"empty", ""},
|
||||
{"odd length", "0123456789abcde"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ParseKeyID(tt.input)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid KeyID length")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKeyID_InvalidHex(t *testing.T) {
|
||||
invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
|
||||
|
||||
_, err := ParseKeyID(invalidHex)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode KeyID")
|
||||
}
|
||||
|
||||
func TestKeyID_String(t *testing.T) {
|
||||
keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
|
||||
|
||||
str := keyID.String()
|
||||
assert.Equal(t, "0123456789abcdef", str)
|
||||
}
|
||||
|
||||
func TestKeyID_RoundTrip(t *testing.T) {
|
||||
original := "fedcba9876543210"
|
||||
|
||||
keyID, err := ParseKeyID(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := keyID.String()
|
||||
assert.Equal(t, original, result)
|
||||
}
|
||||
|
||||
func TestKeyID_ZeroValue(t *testing.T) {
|
||||
keyID := KeyID{}
|
||||
str := keyID.String()
|
||||
assert.Equal(t, "0000000000000000", str)
|
||||
}
|
||||
|
||||
// Test KeyMetadata
|
||||
|
||||
func TestKeyMetadata_JSONMarshaling(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
metadata := KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
|
||||
ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded KeyMetadata
|
||||
err = json.Unmarshal(jsonData, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, metadata.ID, decoded.ID)
|
||||
assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
|
||||
assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
|
||||
}
|
||||
|
||||
func TestKeyMetadata_NoExpiration(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
metadata := KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
|
||||
ExpiresAt: time.Time{}, // Zero value = no expiration
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(metadata)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded KeyMetadata
|
||||
err = json.Unmarshal(jsonData, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, decoded.ExpiresAt.IsZero())
|
||||
}
|
||||
|
||||
// Test PublicKey
|
||||
|
||||
func TestPublicKey_JSONMarshaling(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded PublicKey
|
||||
err = json.Unmarshal(jsonData, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, pubKey.Key, decoded.Key)
|
||||
assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
|
||||
}
|
||||
|
||||
// Test parsePublicKey
|
||||
|
||||
func TestParsePublicKey_Valid(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
metadata := KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
|
||||
}
|
||||
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encode to PEM
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPublic,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
// Parse it back
|
||||
parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, rest)
|
||||
assert.Equal(t, pub, parsed.Key)
|
||||
assert.Equal(t, metadata.ID, parsed.Metadata.ID)
|
||||
}
|
||||
|
||||
func TestParsePublicKey_InvalidPEM(t *testing.T) {
|
||||
invalidPEM := []byte("not a PEM")
|
||||
|
||||
_, _, err := parsePublicKey(invalidPEM, tagRootPublic)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode PEM")
|
||||
}
|
||||
|
||||
func TestParsePublicKey_WrongType(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Encode with wrong type
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "WRONG TYPE",
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
_, _, err = parsePublicKey(pemData, tagRootPublic)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "PEM type")
|
||||
}
|
||||
|
||||
func TestParsePublicKey_InvalidJSON(t *testing.T) {
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPublic,
|
||||
Bytes: []byte("invalid json"),
|
||||
})
|
||||
|
||||
_, _, err := parsePublicKey(pemData, tagRootPublic)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal")
|
||||
}
|
||||
|
||||
func TestParsePublicKey_InvalidKeySize(t *testing.T) {
|
||||
// Create a public key with wrong size
|
||||
pubKey := PublicKey{
|
||||
Key: []byte{0x01, 0x02, 0x03}, // Too short
|
||||
Metadata: KeyMetadata{
|
||||
ID: KeyID{},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPublic,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
_, _, err = parsePublicKey(pemData, tagRootPublic)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
|
||||
}
|
||||
|
||||
func TestParsePublicKey_IDRecomputation(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a public key with WRONG ID
|
||||
wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: wrongID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPublic,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
// Parse should recompute the correct ID
|
||||
parsed, _, err := parsePublicKey(pemData, tagRootPublic)
|
||||
require.NoError(t, err)
|
||||
|
||||
correctID := computeKeyID(pub)
|
||||
assert.Equal(t, correctID, parsed.Metadata.ID)
|
||||
assert.NotEqual(t, wrongID, parsed.Metadata.ID)
|
||||
}
|
||||
|
||||
// Test parsePublicKeyBundle
|
||||
|
||||
func TestParsePublicKeyBundle_Single(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPublic,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
assert.Equal(t, pub, keys[0].Key)
|
||||
}
|
||||
|
||||
func TestParsePublicKeyBundle_Multiple(t *testing.T) {
|
||||
var bundle []byte
|
||||
|
||||
// Create 3 keys
|
||||
for i := 0; i < 3; i++ {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := PublicKey{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(pubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPublic,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
bundle = append(bundle, pemData...)
|
||||
}
|
||||
|
||||
keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 3)
|
||||
}
|
||||
|
||||
func TestParsePublicKeyBundle_Empty(t *testing.T) {
|
||||
_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no keys found")
|
||||
}
|
||||
|
||||
func TestParsePublicKeyBundle_Invalid(t *testing.T) {
|
||||
_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Test PrivateKey
|
||||
|
||||
func TestPrivateKey_JSONMarshaling(t *testing.T) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privKey := PrivateKey{
|
||||
Key: priv,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded PrivateKey
|
||||
err = json.Unmarshal(jsonData, &decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, privKey.Key, decoded.Key)
|
||||
assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
|
||||
}
|
||||
|
||||
// Test parsePrivateKey
|
||||
|
||||
func TestParsePrivateKey_Valid(t *testing.T) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privKey := PrivateKey{
|
||||
Key: priv,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPrivate,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
parsed, err := parsePrivateKey(pemData, tagRootPrivate)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, priv, parsed.Key)
|
||||
}
|
||||
|
||||
func TestParsePrivateKey_InvalidPEM(t *testing.T) {
|
||||
_, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode PEM")
|
||||
}
|
||||
|
||||
func TestParsePrivateKey_TrailingData(t *testing.T) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privKey := PrivateKey{
|
||||
Key: priv,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPrivate,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
// Add trailing data
|
||||
pemData = append(pemData, []byte("extra data")...)
|
||||
|
||||
_, err = parsePrivateKey(pemData, tagRootPrivate)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "trailing PEM data")
|
||||
}
|
||||
|
||||
func TestParsePrivateKey_WrongType(t *testing.T) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
privKey := PrivateKey{
|
||||
Key: priv,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "WRONG TYPE",
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
_, err = parsePrivateKey(pemData, tagRootPrivate)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "PEM type")
|
||||
}
|
||||
|
||||
func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
|
||||
privKey := PrivateKey{
|
||||
Key: []byte{0x01, 0x02, 0x03}, // Too short
|
||||
Metadata: KeyMetadata{
|
||||
ID: KeyID{},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(privKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemData := pem.EncodeToMemory(&pem.Block{
|
||||
Type: tagRootPrivate,
|
||||
Bytes: jsonData,
|
||||
})
|
||||
|
||||
_, err = parsePrivateKey(pemData, tagRootPrivate)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
|
||||
}
|
||||
|
||||
// Test verifyAny
|
||||
|
||||
func TestVerifyAny_ValidSignature(t *testing.T) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
message := []byte("test message")
|
||||
signature := ed25519.Sign(priv, message)
|
||||
|
||||
rootKeys := []PublicKey{
|
||||
{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := verifyAny(rootKeys, message, signature)
|
||||
assert.True(t, result)
|
||||
}
|
||||
|
||||
func TestVerifyAny_InvalidSignature(t *testing.T) {
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
message := []byte("test message")
|
||||
invalidSignature := make([]byte, ed25519.SignatureSize)
|
||||
|
||||
rootKeys := []PublicKey{
|
||||
{
|
||||
Key: pub,
|
||||
Metadata: KeyMetadata{
|
||||
ID: computeKeyID(pub),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := verifyAny(rootKeys, message, invalidSignature)
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestVerifyAny_MultipleKeys(t *testing.T) {
|
||||
// Create 3 key pairs
|
||||
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub3, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
message := []byte("test message")
|
||||
signature := ed25519.Sign(priv1, message)
|
||||
|
||||
rootKeys := []PublicKey{
|
||||
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
|
||||
{Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
|
||||
{Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
|
||||
}
|
||||
|
||||
result := verifyAny(rootKeys, message, signature)
|
||||
assert.True(t, result)
|
||||
}
|
||||
|
||||
func TestVerifyAny_NoMatchingKey(t *testing.T) {
|
||||
_, priv1, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
message := []byte("test message")
|
||||
signature := ed25519.Sign(priv1, message)
|
||||
|
||||
// Only include pub2, not pub1
|
||||
rootKeys := []PublicKey{
|
||||
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
|
||||
}
|
||||
|
||||
result := verifyAny(rootKeys, message, signature)
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestVerifyAny_EmptyKeys(t *testing.T) {
|
||||
message := []byte("test message")
|
||||
signature := make([]byte, ed25519.SignatureSize)
|
||||
|
||||
result := verifyAny([]PublicKey{}, message, signature)
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestVerifyAny_TamperedMessage(t *testing.T) {
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
message := []byte("test message")
|
||||
signature := ed25519.Sign(priv, message)
|
||||
|
||||
rootKeys := []PublicKey{
|
||||
{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
|
||||
}
|
||||
|
||||
// Verify with different message
|
||||
tamperedMessage := []byte("different message")
|
||||
result := verifyAny(rootKeys, tamperedMessage, signature)
|
||||
assert.False(t, result)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user