Compare commits

...

62 Commits

Author SHA1 Message Date
crn4
ca432ff681 refactor components creation 2026-01-08 17:34:53 +01:00
crn4
7b5d7aeb2e refactor 2026-01-08 17:19:16 +01:00
crn4
3bdce8d0b6 added support for ssh auth users to components 2026-01-08 16:36:40 +01:00
crn4
d534ce9dfc minor changes to benchmarks 2026-01-08 15:06:02 +01:00
crn4
bbc2b42807 changes after main merge 2026-01-08 13:44:15 +01:00
crn4
db9cc52c96 conflicts resolution 2026-01-08 13:25:10 +01:00
Diego Noguês
fb71b0d04b [infrastructure] fix: disable Caddy debug (#5067) 2026-01-08 12:49:45 +01:00
Maycon Santos
ab7d6b2196 [misc] add new getting started to release (#5057) 2026-01-08 12:12:50 +01:00
Maycon Santos
9c5b2575e3 [misc] add embedded provider support metrics
count local vs idp users if embedded
2026-01-08 12:12:19 +01:00
Bethuel Mmbaga
00e2689ffb [management] Fix race condition in experimental network map when deleting account (#5064) 2026-01-08 14:10:09 +03:00
Misha Bragin
cf535f8c61 [management] Fix role change in transaction and update readme (#5060) 2026-01-08 12:07:59 +01:00
Maycon Santos
24df442198 Revert "[relay] Update GO version and QUIC version (#4736)" (#5055)
This reverts commit 8722b79799.
2026-01-07 19:02:20 +01:00
Zoltan Papp
8722b79799 [relay] Update GO version and QUIC version (#4736)
- Go 1.25.5
- QUIC 0.55.0
2026-01-07 16:30:29 +01:00
Vlad
afcdef6121 [management] add ssh authorized users to network map cache (#5048) 2026-01-07 15:53:18 +01:00
Zoltan Papp
12a7fa24d7 Add support for disabling eBPF WireGuard proxy via environment variable (#5047) 2026-01-07 15:34:52 +01:00
Zoltan Papp
6ff9aa0366 Refactor SSH server to manage listener lifecycle and expose active address via Addr method. (#5036) 2026-01-07 15:34:26 +01:00
Misha Bragin
e586c20e36 [management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)
Embed Dex as a built-in IdP to simplify self-hosting setup.
Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management.

more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
2026-01-07 14:52:32 +01:00
Pascal Fischer
5393ad948f [management] fix nil handling for extra settings (#5049) 2026-01-07 13:05:39 +01:00
Bethuel Mmbaga
20d6beff1b [management] Increment network serial on peer update (#5051)
Increment the serial on peer update and prevent double serial increments and account updates when updating a user while there are peers set to expire
2026-01-07 14:59:49 +03:00
Bethuel Mmbaga
d35b7d675c [management] Refactor integrated peer deletion (#5042) 2026-01-07 14:00:39 +03:00
Viktor Liu
f012fb8592 [client] Add port forwarding to ssh proxy (#5031)
* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
2026-01-07 12:18:04 +08:00
Vlad
7142d45ef3 [management] network map builder concurrent batch processing for peer updates (#5040) 2026-01-06 19:25:55 +01:00
crn4
3209b241d9 minor opts 2026-01-06 12:03:01 +01:00
Dennis Schridde
9bd578d4ea Fix ui-post-install.sh to use the full username (#4809)
Fixes #4808 by extracting the full username by:

- Get PID using pgrep
- Get UID from PID using /proc/${PID}/loginuid
- Get user name from UID using id
Also replaces "complex" pipe from ps to sed with a (hopefully) "simpler" (as in requiring less knowledge about the arguments of ps and regexps) invocation of cat and id.
2026-01-06 11:36:19 +01:00
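The username-extraction steps described in that commit message (PID via pgrep, UID from /proc/${PID}/loginuid, name via id) can be sketched in Go; this is a hypothetical helper for illustration, not the shell script the commit actually changes:

```go
package main

import (
	"fmt"
	"os"
	"os/user"
	"strings"
)

// loginUsername mirrors the script's steps: read the login UID of the
// given PID from /proc/<pid>/loginuid, then resolve the full username
// from the user database (the equivalent of `id -un <uid>`).
func loginUsername(pid int) (string, error) {
	raw, err := os.ReadFile(fmt.Sprintf("/proc/%d/loginuid", pid))
	if err != nil {
		return "", fmt.Errorf("read loginuid: %w", err)
	}
	uid := strings.TrimSpace(string(raw))
	u, err := user.LookupId(uid)
	if err != nil {
		return "", fmt.Errorf("lookup uid %s: %w", uid, err)
	}
	return u.Username, nil
}

func main() {
	name, err := loginUsername(os.Getpid())
	if err != nil {
		// e.g. loginuid is unset (4294967295) or not on Linux
		fmt.Fprintln(os.Stderr, err)
		return
	}
	fmt.Println(name)
}
```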
Pascal Fischer
f022e34287 [shared] allow setting a user agent for the rest client (#5037) 2026-01-06 10:52:36 +01:00
Bethuel Mmbaga
7bb4fc3450 [management] Refactor integrated peer validator (#5035) 2026-01-05 20:55:22 +03:00
Maycon Santos
07856f516c [client] Fix/stuck connecting when can't access api.netbird.io (#5033)
- Connect on daemon start only if the file existed before
- fixed a bug that happened when the default profile config was removed, which would recreate it and reset the active profile to the default.
2026-01-05 13:53:17 +01:00
Zoltan Papp
08b782d6ba [client] Fix update download url (#5023) 2026-01-03 20:05:38 +03:00
Maycon Santos
80a312cc9c [client] add verbose flag for free ad tests (#5021)
add verbose flag for free ad tests
2026-01-03 11:32:41 +01:00
Zoltan Papp
9ba067391f [client] Fix semaphore slot leaks (#5018)
- Remove WaitGroup, make SemaphoreGroup a pure semaphore
- Make Add() return error instead of silently failing on context cancel
- Remove context parameter from Done() to prevent slot leaks
- Fix missing Done() call in conn.go error path
2026-01-03 09:10:02 +01:00
Pascal Fischer
7ac65bf1ad [management] Fix/delete groups without lock (#5012) 2025-12-31 11:53:20 +01:00
Zoltan Papp
2e9c316852 Fix UI stuck in "Connecting" state when daemon reports "Connected" status. (#5014)
The UI can get stuck showing "Connecting" status even after the daemon successfully connects and reports "Connected" status. This occurs because the condition to update the UI to "Connected" state checks the wrong flag.
2025-12-31 11:50:43 +01:00
shuuri-labs
96cdd56902 Feat/add support for forcing device auth flow on ios (#4944)
* updates to client file writing

* numerous

* minor

* - Align OnLoginSuccess behavior with Android (only call on nil error)
- Remove verbose debug logging from WaitToken in device_flow.go
- Improve TUN FD=0 fallback comments and warning messages
- Document why config save after login differs from Android

* Add nolint directive for staticcheck SA1029 in login.go

* Fix CodeRabbit review issues for iOS/tvOS SDK

- Remove goroutine from OnLoginSuccess callback, invoke synchronously
- Stop treating PermissionDenied as success, propagate as permanent error
- Replace context.TODO() with bounded timeout context (30s) in RequestAuthInfo
- Handle DirectUpdateOrCreateConfig errors in IsLoginRequired and LoginForMobile
- Add permission enforcement to DirectUpdateOrCreateConfig for existing configs
- Fix variable shadowing in device_ios.go where err was masked by := in else block

* Address additional CodeRabbit review issues for iOS/tvOS SDK

- Make tunFd == 0 a hard error with exported ErrInvalidTunnelFD (remove dead fallback code)
- Apply defaults in ConfigFromJSON to prevent partially-initialized configs
- Add nil guards for listener/urlOpener interfaces in public SDK entry points
- Reorder config save before OnLoginSuccess to prevent teardown race
- Add explanatory comment for urlOpener.Open goroutine

* Make urlOpener.Open() synchronous in device auth flow
2025-12-30 16:41:36 +00:00
Misha Bragin
9ed1437442 Add DEX IdP Support (#4949) 2025-12-30 07:42:34 -05:00
Pascal Fischer
a8604ef51c [management] filter own peer when having a group to peer policy to themself (#4956) 2025-12-30 10:49:43 +01:00
Nicolas Henneaux
d88e046d00 fix(router): nft tables limit number of peers source (#4852)
* fix(router): nft tables limit number of peers source batching them, failing at 3277 prefixes on nftables v1.0.9 with Ubuntu 24.04.3 LTS,  6.14.0-35-generic #35~24.04.1-Ubuntu

* fix(router): nft tables limit number of prefixes on ipSet creation
2025-12-30 10:48:17 +01:00
Pascal Fischer
1d2c7776fd [management] apply login filter only for setup key peers (#4943) 2025-12-30 10:46:00 +01:00
Haruki Hasegawa
4035f07248 [client] Fix Advanced Settings not opening on Windows with Japanese locale (#4455) (#4637)
The Fyne framework does not support TTC font files.
Use the default system font (Segoe UI) instead, so Windows can
automatically fall back to a Japanese font when needed.
2025-12-30 10:36:12 +01:00
Zoltan Papp
ef2721f4e1 Filter out own peer from remote peers list during peer updates. (#4986) 2025-12-30 10:29:45 +01:00
Louis Li
e11970e32e [client] add reset for management backoff (#4935)
Reset the management gRPC client backoff after successfully connecting to the management API.

Current Situation:
If the connection duration exceeds MaxElapsedTime, when the connection is interrupted, the backoff fails immediately due to timeout and does not actually perform a retry.
2025-12-30 08:37:49 +01:00
crn4
7566afd7d0 components approach - we are sending all components needed for nmap assembling on client side 2025-12-29 17:31:16 +01:00
Maycon Santos
38f9d5ed58 [infra] Preset signal port on templates (#5004)
When passing certificates to signal, it will select port 443 when no port is supplied. This change forces port 80.
2025-12-29 18:07:06 +03:00
Pascal Fischer
b6a327e0c9 [management] fix scanning authorized user on policy rule (#5002) 2025-12-29 15:03:16 +01:00
Zoltan Papp
67f7b2404e [client, management] Feature/ssh fine grained access (#4969)
Add fine-grained SSH access control with authorized users/groups
2025-12-29 12:50:41 +01:00
Zoltan Papp
73201c4f3e Add conditional checks for FreeBSD diff file generation in release workflow (#5001) 2025-12-29 12:47:38 +01:00
Carlos Hernandez
33d1761fe8 Apply DNS host config on change only (#4695)
Adds a per-instance uint64 hash to DefaultServer to detect identical merged host DNS configs (including extra domains). applyHostConfig computes and compares the hash, skips applying if unchanged, treats hash errors as a fail-safe (proceed to apply), and updates the stored hash only after successful hashing and apply.
2025-12-29 12:43:57 +01:00
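The hash-and-skip flow described above can be sketched like this (hypothetical types and an FNV hash over the serialized config; the real code hashes its own merged host-config structure):

```go
package main

import (
	"encoding/json"
	"fmt"
	"hash/fnv"
)

type hostDNSConfig struct {
	Domains    []string
	ServerIP   string
	ServerPort int
}

type hostConfigurator struct {
	lastHash uint64
	apply    func(hostDNSConfig) error
}

func hashConfig(cfg hostDNSConfig) (uint64, error) {
	data, err := json.Marshal(cfg)
	if err != nil {
		return 0, err
	}
	h := fnv.New64a()
	_, _ = h.Write(data)
	return h.Sum64(), nil
}

// applyIfChanged skips identical configs, treats hash errors as a
// fail-safe (proceed to apply), and stores the hash only after both
// hashing and apply succeeded.
func (c *hostConfigurator) applyIfChanged(cfg hostDNSConfig) error {
	newHash, err := hashConfig(cfg)
	if err == nil && newHash != 0 && newHash == c.lastHash {
		return nil // unchanged: skip re-applying host DNS config
	}
	if applyErr := c.apply(cfg); applyErr != nil {
		return applyErr
	}
	if err == nil {
		c.lastHash = newHash
	}
	return nil
}

func main() {
	applied := 0
	c := &hostConfigurator{apply: func(hostDNSConfig) error { applied++; return nil }}
	cfg := hostDNSConfig{Domains: []string{"example.com"}, ServerIP: "100.64.0.1", ServerPort: 53}
	_ = c.applyIfChanged(cfg)
	_ = c.applyIfChanged(cfg) // identical config: skipped
	fmt.Println(applied)
}
```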
August
aa914a0f26 [docs] Fix broken image link (#4876) 2025-12-24 22:06:35 +05:00
Maycon Santos
ab6a9e85de [misc] Use new sign pipelines 0.1.0 (#4993) 2025-12-24 22:03:14 +05:00
Maycon Santos
d3b123c76d [ci] Add FreeBSD port release job to GitHub Actions (#4916)
adds a job that produces new freebsd release files
2025-12-24 11:22:33 +01:00
Viktor Liu
fc4932a23f [client] Fix Linux UI flickering on state updates (#4886) 2025-12-24 11:06:13 +01:00
Zoltan Papp
b7e98acd1f [client] Android profile switch (#4884)
Expose the profile-manager service for Android. Logout was not part of the manager service implementation. In the future, I recommend moving this logic there.
2025-12-22 22:09:05 +01:00
Maycon Santos
433bc4ead9 [client] lookup for management domains using an additional timeout (#4983)
in some cases iOS and macOS may lock up when looking up management domains during network changes

This change introduces an additional timeout on top of the context call
2025-12-22 20:04:52 +01:00
Zoltan Papp
011cc81678 [client, management] auto-update (#4732) 2025-12-19 19:57:39 +01:00
crn4
e93d4132d3 Merge branch 'main' into nmap/compaction 2025-12-18 16:50:41 +01:00
Zoltan Papp
537151e0f3 Remove redundant lock in peer update logic to avoid deadlock with exported functions (#4953) 2025-12-17 13:55:33 +01:00
Zoltan Papp
a9c28ef723 Add stack trace for bundle (#4957) 2025-12-17 13:49:02 +01:00
Pascal Fischer
c29bb1a289 [management] use xid as request id for logging (#4955) 2025-12-16 14:02:37 +01:00
crn4
21e5e6ddff cached compaction 2025-12-05 00:29:58 +01:00
crn4
10fb18736b benchmark on both maps was added 2025-12-05 00:29:58 +01:00
crn4
942abeca0c select nmap based on peer version 2025-12-05 00:29:58 +01:00
crn4
e184a43e8a firewallrules compacted 2025-12-05 00:29:58 +01:00
crn4
f33f84299f routes compacted 2025-12-05 00:29:58 +01:00
243 changed files with 27178 additions and 2453 deletions


@@ -39,7 +39,7 @@ jobs:
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
-time go test -timeout 8m -failfast -p 1 ./client/...
+time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...


@@ -9,7 +9,7 @@ on:
pull_request:
env:
-SIGN_PIPE_VER: "v0.0.23"
+SIGN_PIPE_VER: "v0.1.0"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,6 +19,100 @@ concurrency:
cancel-in-progress: true
jobs:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release:
runs-on: ubuntu-latest-m
env:


@@ -243,6 +243,7 @@ jobs:
working-directory: infrastructure_files/artifacts
run: |
sleep 30
+docker compose logs
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db

.gitignore

@@ -31,3 +31,4 @@ infrastructure_files/setup-*.env
.DS_Store
vendor/
/netbird
+client/netbird-electron/


@@ -713,8 +713,10 @@ checksum:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
+- glob: ./infrastructure_files/getting-started.sh
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh
+- glob: ./infrastructure_files/getting-started.sh


@@ -85,7 +85,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Infrastructure requirements:**
- A Linux VM with at least **1CPU** and **2GB** of memory.
-- The VM should be publicly accessible on TCP ports **80** and **443** and UDP ports: **3478**, **49152-65535**.
+- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
- **Public domain** name pointing to the VM.
**Software requirements:**
@@ -98,7 +98,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird
**Steps**
- Download and run the installation script:
```bash
-export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started-with-zitadel.sh | bash
+export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
```
- Once finished, you can manage the resources via `docker-compose`
@@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
-<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
+<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.


@@ -59,7 +59,6 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
-cfgFile string
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
@@ -68,18 +67,16 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
-stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
-func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
+func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
-cfgFile: platformFiles.ConfigurationFilePath(),
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -87,15 +84,20 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
-stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
-func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
+func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
+cfgFile := platformFiles.ConfigurationFilePath()
+stateFile := platformFiles.StateFilePath()
+log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
-ConfigPath: c.cfgFile,
+ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -122,16 +124,22 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
-c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
-return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
+c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
+return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
-func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
+func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
+cfgFile := platformFiles.ConfigurationFilePath()
+stateFile := platformFiles.StateFilePath()
+log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
-ConfigPath: c.cfgFile,
+ConfigPath: cfgFile,
})
if err != nil {
return err
@@ -149,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
-c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
-return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
+c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
+return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources


@@ -0,0 +1,257 @@
//go:build android
package android
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
// Android-specific config filename (different from desktop default.json)
defaultConfigFilename = "netbird.cfg"
// Subdirectory for non-default profiles (must match Java Preferences.java)
profilesSubdir = "profiles"
// Android uses a single user context per app (non-empty username required by ServiceManager)
androidUsername = "android"
)
// Profile represents a profile for gomobile
type Profile struct {
Name string
IsActive bool
}
// ProfileArray wraps profiles for gomobile compatibility
type ProfileArray struct {
items []*Profile
}
// Length returns the number of profiles
func (p *ProfileArray) Length() int {
return len(p.items)
}
// Get returns the profile at index i
func (p *ProfileArray) Get(i int) *Profile {
if i < 0 || i >= len(p.items) {
return nil
}
return p.items[i]
}
/*
/data/data/io.netbird.client/files/ ← configDir parameter
├── netbird.cfg ← Default profile config
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
└── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
// It wraps the internal profilemanager to provide Android-specific behavior
type ProfileManager struct {
configDir string
serviceMgr *profilemanager.ServiceManager
}
// NewProfileManager creates a new profile manager for Android
func NewProfileManager(configDir string) *ProfileManager {
// Set the default config path for Android (stored in root configDir, not profiles/)
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
// Set global paths for Android
profilemanager.DefaultConfigPathDir = configDir
profilemanager.DefaultConfigPath = defaultConfigPath
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
// Create ServiceManager with profiles/ subdirectory
// This avoids modifying the global ConfigDirOverride for profile listing
profilesDir := filepath.Join(configDir, profilesSubdir)
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
return &ProfileManager{
configDir: configDir,
serviceMgr: serviceMgr,
}
}
// ListProfiles returns all available profiles
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
// Convert internal profiles to Android Profile type
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
Name: p.Name,
IsActive: p.IsActive,
})
}
return &ProfileArray{items: profiles}, nil
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
config, err := profilemanager.ReadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to read profile config: %w", err)
}
// Clear authentication by removing private key and SSH key
config.PrivateKey = ""
config.SSHKey = ""
// Save config using internal profilemanager
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}


@@ -85,6 +85,9 @@ var (
// Execute executes the root command.
func Execute() error {
+if isUpdateBinary() {
+return updateCmd.Execute()
+}
return rootCmd.Execute()
}


@@ -0,0 +1,176 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}


@@ -0,0 +1,276 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

client/cmd/signer/main.go Normal file

@@ -0,0 +1,21 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}


@@ -0,0 +1,220 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}


@@ -0,0 +1,74 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h, 8760h)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}


@@ -634,7 +634,11 @@ func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward
return err
}
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
if err := validateDestinationPort(remoteAddr); err != nil {
return fmt.Errorf("invalid remote address: %w", err)
}
log.Debugf("Local port forwarding: %s -> %s", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -652,7 +656,11 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return err
}
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
if err := validateDestinationPort(localAddr); err != nil {
return fmt.Errorf("invalid local address: %w", err)
}
log.Debugf("Remote port forwarding: %s -> %s", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
@@ -663,6 +671,35 @@ func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forwar
return nil
}
// validateDestinationPort checks that the destination address has a valid port.
// Port 0 is only valid for bind addresses (where the OS picks an available port),
// not for destination addresses where we need to connect.
func validateDestinationPort(addr string) error {
if strings.HasPrefix(addr, "/") || strings.HasPrefix(addr, "./") {
return nil
}
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("parse address %s: %w", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port %s: %w", portStr, err)
}
if port == 0 {
return fmt.Errorf("port 0 is not valid for destination address")
}
if port < 0 || port > 65535 {
return fmt.Errorf("port %d out of range (1-65535)", port)
}
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {


@@ -127,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}


@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)

client/cmd/update.go Normal file

@@ -0,0 +1,13 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}


@@ -0,0 +1,75 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}


@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder, false)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available


@@ -386,6 +386,97 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
const octet2Count = 25
const octet3Count = 255
prefixes := make([]netip.Prefix, 0, (octet2Count-1)*(octet3Count-1))
for i := 1; i < octet2Count; i++ {
for j := 1; j < octet3Count; j++ {
addr := netip.AddrFrom4([4]byte{192, byte(j), byte(i), 0})
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
if _, err := exec.LookPath("iptables-save"); err != nil {
t.Skipf("iptables-save not available on this system: %v", err)
}
// First ensure iptables-nft tables exist by running iptables-save
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
}
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(want), len(got), "expression count mismatch")

View File

@@ -48,9 +48,11 @@ const (
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const refreshRulesMapError = "refresh rules map: %w"
// maxPrefixesSet: sets with more than ~1638 prefixes start to fail, so keep some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
@@ -513,16 +515,35 @@ func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, err
}
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
var subEnd int
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd = min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}
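The new `createIpSet` path above flushes the set to the kernel in batches because a single transaction with more than roughly 1638 prefixes starts to fail. A minimal sketch of just the batching arithmetic, assuming (as in the diff) that interval sets store two elements per prefix, hence the `*2`:

```go
package main

import "fmt"

// maxPrefixesSet mirrors the margin chosen in the diff: batches much
// larger than ~1638 prefixes fail, so 1500 leaves headroom.
const maxPrefixesSet = 1500

// batchRanges yields the [start, end) bounds used to flush a large
// element slice to the kernel in several smaller transactions.
func batchRanges(n, batchSize int) [][2]int {
	var out [][2]int
	for start := 0; start < n; start += batchSize {
		end := min(start+batchSize, n)
		out = append(out, [2]int{start, end})
	}
	return out
}

func main() {
	// Each prefix contributes two interval elements, as in the diff.
	ranges := batchRanges(6100*2, maxPrefixesSet*2)
	fmt.Println(len(ranges), ranges[len(ranges)-1])
}
```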

View File

@@ -4,6 +4,7 @@
package device
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
@@ -45,10 +46,31 @@ func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu
}
}
// ErrInvalidTunnelFD is returned when the tunnel file descriptor is invalid (0).
// This typically means the Swift code couldn't find the utun control socket.
var ErrInvalidTunnelFD = fmt.Errorf("invalid tunnel file descriptor: fd is 0 (Swift failed to locate utun socket)")
func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
var tunDevice tun.Device
var err error
// Validate the tunnel file descriptor.
// On iOS/tvOS, the FD must be provided by the NEPacketTunnelProvider.
// A value of 0 means the Swift code couldn't find the utun control socket
// (the low-level APIs like ctl_info, sockaddr_ctl may not be exposed in
// tvOS SDK headers). This is a hard error - there's no viable fallback
// since tun.CreateTUN() cannot work within the iOS/tvOS sandbox.
if t.tunFd == 0 {
log.Errorf("Tunnel file descriptor is 0 - Swift code failed to locate the utun control socket. " +
"On tvOS, ensure the NEPacketTunnelProvider is properly configured and the tunnel is started.")
return nil, ErrInvalidTunnelFD
}
// Normal iOS/tvOS path: use the provided file descriptor from NEPacketTunnelProvider
var dupTunFd int
dupTunFd, err = unix.Dup(t.tunFd)
if err != nil {
log.Errorf("Unable to dup tun fd: %v", err)
return nil, err
@@ -60,7 +82,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
_ = unix.Close(dupTunFd)
return nil, err
}
tunDevice, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
tunDevice, err = tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
log.Errorf("Unable to create new tun device from fd: %v", err)
_ = unix.Close(dupTunFd)

View File

@@ -3,12 +3,19 @@
package wgproxy
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
const (
envDisableEBPFWGProxy = "NB_DISABLE_EBPF_WG_PROXY"
)
type KernelFactory struct {
wgPort int
mtu uint16
@@ -22,6 +29,12 @@ func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
mtu: mtu,
}
if isEBPFDisabled() {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Infof("eBPF WireGuard proxy is disabled via %s environment variable", envDisableEBPFWGProxy)
return f
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, mtu)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
@@ -47,3 +60,16 @@ func (w *KernelFactory) Free() error {
}
return w.ebpfProxy.Free()
}
func isEBPFDisabled() bool {
val := os.Getenv(envDisableEBPFWGProxy)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableEBPFWGProxy, err)
return false
}
return disabled
}

View File

@@ -24,10 +24,14 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -39,11 +43,13 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -52,13 +58,15 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitialAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitialAutoUpdate,
engineMutex: sync.Mutex{},
}
}
@@ -162,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
var path string
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
// On mobile, use the provided state file path directly
if !fileExists(mobileDependency.StateFilePath) {
if err := createFile(mobileDependency.StateFilePath); err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDependency.StateFilePath
} else {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
operation := func() error {
// if the context is cancelled we do not start a new backoff cycle
@@ -273,7 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -283,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// doInitialAutoUpdate is true when the user clicked "Connect" in the UI
if c.doInitialAutoUpdate {
log.Infof("engine started by UI, running auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -56,6 +57,7 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -109,6 +111,9 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -327,6 +332,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -354,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
log.Errorf("failed to add updater logs: %v", err)
}
return nil
}
@@ -522,6 +535,18 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
@@ -630,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
func (g *BundleGenerator) addUpdateLogs() error {
inst := installer.New()
logFiles := inst.LogFiles()
if len(logFiles) == 0 {
return nil
}
log.Infof("adding updater logs")
for _, logFile := range logFiles {
data, err := os.ReadFile(logFile)
if err != nil {
log.Warnf("failed to read update log file %s: %v", logFile, err)
continue
}
baseName := filepath.Base(logFile)
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
}
}
return nil
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
@@ -26,6 +27,11 @@ type Resolver struct {
mutex sync.RWMutex
}
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
return err
}
var aRecords, aaaaRecords []dns.RR
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {

View File

@@ -80,6 +80,7 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -207,6 +208,7 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -586,8 +588,29 @@ func (s *DefaultServer) applyHostConfig() {
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
// Fall through to apply config anyway (fail-safe approach)
} else if s.currentConfigHash == hash {
log.Debugf("not applying host config as there are no changes")
return
}
log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
return
}
// Only update hash if it was computed successfully and config was applied
if err == nil {
s.currentConfigHash = hash
}
s.registerFallback(config)
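`applyHostConfig` now hashes the config with `hashstructure` and skips reapplication when nothing changed. A stdlib stand-in for the same skip-if-unchanged idea (JSON + FNV replaces `hashstructure` here purely for illustration), including the `^uint64(0)` initialization trick so the first real config is always applied:

```go
package main

import (
	"encoding/json"
	"fmt"
	"hash/fnv"
)

// hostConfig is a simplified stand-in for the real HostDNSConfig.
type hostConfig struct {
	Domains []string
	Primary string
}

// hashConfig derives a stable fingerprint so the (expensive, disruptive)
// host reconfiguration can be skipped when nothing changed. The diff uses
// hashstructure; JSON+FNV is a stdlib approximation of the idea.
func hashConfig(c hostConfig) (uint64, error) {
	b, err := json.Marshal(c)
	if err != nil {
		return 0, err
	}
	h := fnv.New64a()
	h.Write(b)
	return h.Sum64(), nil
}

func main() {
	// Initialize the "current" hash to max uint64, as the diff does, so
	// the first computed hash can never match it by accident.
	current := ^uint64(0)
	next, _ := hashConfig(hostConfig{Domains: []string{"example.com"}, Primary: "10.0.0.1"})
	fmt.Println(next != current) // first config always applies
}
```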

View File

@@ -1602,7 +1602,10 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
applyHostConfigCall: 4,
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
// the domain remains in the config (ref count goes from 2 to 1), so the host
// config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 3,
},
{
name: "Config update with new domains after registration",
@@ -1657,7 +1660,10 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
applyHostConfigCall: 3,
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
// it's removed from extraDomains but still remains in the config (from customZones),
// so the host config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 2,
},
{
name: "Register domain that is part of nameserver group",

View File

@@ -42,14 +42,13 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -73,6 +72,7 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -201,6 +201,9 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -221,17 +224,7 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -247,28 +240,12 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -308,6 +285,10 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
@@ -541,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -749,6 +737,41 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -758,6 +781,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -1094,6 +1121,15 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
@@ -1102,32 +1138,34 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return err
}
} else {
err := e.removePeers(networkMap.GetRemotePeers())
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(networkMap.GetRemotePeers())
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(networkMap.GetRemotePeers())
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager cannot figure out the peers' parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial
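`updateNetworkMap` now filters the local public key out of the server-provided peer list before diffing, so the engine never tries to add a WireGuard peer for itself. The filter in isolation — the `remotePeer` type here is a stand-in for `mgmProto.RemotePeerConfig`:

```go
package main

import "fmt"

// remotePeer is a simplified stand-in for mgmProto.RemotePeerConfig.
type remotePeer struct{ WgPubKey string }

// filterOwnPeer drops our own WireGuard public key from the remote peer
// list before the add/modify/remove diffing runs.
func filterOwnPeer(peers []remotePeer, localPubKey string) []remotePeer {
	out := make([]remotePeer, 0, len(peers))
	for _, p := range peers {
		if p.WgPubKey != localPubKey {
			out = append(out, p)
		}
	}
	return out
}

func main() {
	peers := []remotePeer{{"A"}, {"ME"}, {"B"}}
	fmt.Println(filterOwnPeer(peers, "ME"))
}
```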

View File

@@ -11,15 +11,18 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
UpdateSSHAuth(config *sshauth.Config)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -353,3 +356,38 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus()
}
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil {
return
}
if e.sshServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
// Update SSH server with new authorization configuration
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.sshServer.UpdateSSHAuth(authConfig)
}
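`updateSSHServerAuth` rejects the whole update if any wire-format hash is not exactly 16 bytes, rather than applying a partial allow-list. A sketch of that all-or-nothing validation, with the types simplified from `sshuserhash.UserIDHash`:

```go
package main

import "fmt"

// userIDHash is a simplified stand-in for sshuserhash.UserIDHash.
type userIDHash [16]byte

// toUserHashes converts wire hashes into fixed-size arrays. One bad
// entry aborts the whole conversion, mirroring the diff's decision to
// skip the auth update entirely instead of applying a partial list.
func toUserHashes(raw [][]byte) ([]userIDHash, bool) {
	out := make([]userIDHash, len(raw))
	for i, h := range raw {
		if len(h) != 16 {
			return nil, false
		}
		copy(out[i][:], h)
	}
	return out, true
}

func main() {
	good := make([]byte, 16)
	_, ok := toUserHashes([][]byte{good, make([]byte, 15)})
	fmt.Println(ok)
}
```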

View File

@@ -253,6 +253,7 @@ func TestEngine_SSH(t *testing.T) {
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -414,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -647,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -812,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1014,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1540,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}
@@ -1638,7 +1631,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -110,7 +110,6 @@ func wakeUpListen(ctx context.Context) {
}
if newHash == initialHash {
log.Tracef("no wakeup detected")
continue
}

View File

@@ -148,13 +148,15 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
conn.semaphore.Add(engineCtx)
if err := conn.semaphore.Add(engineCtx); err != nil {
return err
}
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.opened {
conn.semaphore.Done(engineCtx)
conn.semaphore.Done()
return nil
}
@@ -165,6 +167,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
conn.semaphore.Done()
return err
}
conn.workerICE = workerICE
@@ -200,7 +203,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx)
conn.semaphore.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects updateWireGuardPeer and cancelFunc
// mu protects cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,11 +86,9 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -3,9 +3,11 @@ package profilemanager
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
@@ -165,19 +167,26 @@ func getConfigDir() (string, error) {
if ConfigDirOverride != "" {
return ConfigDirOverride, nil
}
configDir, err := os.UserConfigDir()
base, err := baseConfigDir()
if err != nil {
return "", err
}
configDir = filepath.Join(configDir, "netbird")
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
configDir := filepath.Join(base, "netbird")
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
}
func baseConfigDir() (string, error) {
if runtime.GOOS == "darwin" {
if u, err := user.Current(); err == nil && u.HomeDir != "" {
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
}
}
return configDir, nil
return os.UserConfigDir()
}
func getConfigDirForUser(username string) (string, error) {
@@ -676,7 +685,7 @@ func update(input ConfigInput) (*Config, error) {
return config, nil
}
// GetConfig read config file and return with Config. Errors out if it does not exist
// GetConfig reads the config file and returns the Config. Errors out if it does not exist
func GetConfig(configPath string) (*Config, error) {
return readConfig(configPath, false)
}
@@ -812,3 +821,85 @@ func readConfig(configPath string, createIfMissing bool) (*Config, error) {
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
}
// DirectWriteOutConfig writes config directly without atomic temp file operations.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectWriteOutConfig(path string, config *Config) error {
return util.DirectWriteJson(context.Background(), path, config)
}
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input)
if err != nil {
return nil, err
}
err = util.DirectWriteJson(context.Background(), input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
if err := util.EnforcePermission(input.ConfigPath); err != nil {
log.Errorf("failed to enforce permission on config file: %v", err)
}
return directUpdate(input)
}
func directUpdate(input ConfigInput) (*Config, error) {
config := &Config{}
if _, err := util.ReadJson(input.ConfigPath, config); err != nil {
return nil, err
}
updated, err := config.apply(input)
if err != nil {
return nil, err
}
if updated {
if err := util.DirectWriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}
return config, nil
}
// ConfigToJSON serializes a Config struct to a JSON string.
// This is useful for exporting config to alternative storage mechanisms
// (e.g., UserDefaults on tvOS where file writes are blocked).
func ConfigToJSON(config *Config) (string, error) {
bs, err := json.MarshalIndent(config, "", " ")
if err != nil {
return "", err
}
return string(bs), nil
}
// ConfigFromJSON deserializes a JSON string to a Config struct.
// This is useful for restoring config from alternative storage mechanisms.
// After unmarshaling, defaults are applied to ensure the config is fully initialized.
func ConfigFromJSON(jsonStr string) (*Config, error) {
config := &Config{}
err := json.Unmarshal([]byte(jsonStr), config)
if err != nil {
return nil, err
}
// Apply defaults to ensure required fields are initialized.
// This mirrors what readConfig does after loading from file.
if _, err := config.apply(ConfigInput{}); err != nil {
return nil, fmt.Errorf("failed to apply defaults to config: %w", err)
}
return config, nil
}

View File

@@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) {
}
type ServiceManager struct {
profilesDir string // If set, overrides ConfigDirOverride for profile operations
}
func NewServiceManager(defaultConfigPath string) *ServiceManager {
@@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
return &ServiceManager{}
}
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
// This allows setting the profiles directory without modifying the global ConfigDirOverride
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
if defaultConfigPath != "" {
DefaultConfigPath = defaultConfigPath
}
return &ServiceManager{
profilesDir: profilesDir,
}
}
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
if err := os.MkdirAll(DefaultConfigPathDir, 0o700); err != nil {
@@ -114,14 +126,6 @@ func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
log.Warnf("failed to set permissions for default profile: %v", err)
}
if err := s.SetActiveProfileState(&ActiveProfileState{
Name: "default",
Username: "",
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return false, fmt.Errorf("failed to set active profile state: %w", err)
}
return true, nil
}
@@ -240,7 +244,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
configDir, err := getConfigDirForUser(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -270,7 +274,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := getConfigDirForUser(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -302,7 +306,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
}
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := getConfigDirForUser(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
@@ -361,7 +365,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
configDir, err := getConfigDirForUser(activeProf.Username)
configDir, err := s.getConfigDir(activeProf.Username)
if err != nil {
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
return defaultStatePath
@@ -369,3 +373,12 @@ func (s *ServiceManager) GetStatePath() string {
return filepath.Join(configDir, activeProf.Name+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
func (s *ServiceManager) getConfigDir(username string) (string, error) {
if s.profilesDir != "" {
return s.profilesDir, nil
}
return getConfigDirForUser(username)
}

View File

@@ -0,0 +1,35 @@
// Package updatemanager provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
// # Overview
//
// The update manager operates as a background service that continuously monitors for
// available updates and automatically initiates the update process when conditions are met.
// It integrates with the installer package to perform the actual installation.
//
// # Update Flow
//
// The complete update process follows these steps:
//
// 1. Manager receives update directive via SetVersion() or detects new version
// 2. Manager validates update should proceed (version comparison, rate limiting)
// 3. Manager publishes "updating" event to status recorder
// 4. Manager persists UpdateState to track update attempt
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
// 6. Manager triggers installation via installer.RunInstallation()
// 7. Installer package handles the actual installation process
// 8. On next startup, CheckUpdateSuccess() verifies update completion
// 9. Manager publishes success/failure event to status recorder
// 10. Manager cleans up UpdateState
//
// # State Management
//
// Update state is persisted across restarts to track update attempts:
//
// - PreUpdateVersion: Version before update attempt
// - TargetVersion: Version attempting to update to
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updatemanager

View File

@@ -0,0 +1,138 @@
package downloader
import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
const (
userAgent = "NetBird agent installer/%s"
DefaultRetryDelay = 3 * time.Second
)
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
log.Debugf("starting download from %s", url)
out, err := os.Create(dstFile)
if err != nil {
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
}
defer func() {
if cerr := out.Close(); cerr != nil {
log.Warnf("error closing file %q: %v", dstFile, cerr)
}
}()
// First attempt
err = downloadToFileOnce(ctx, url, out)
if err == nil {
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
// If retryDelay is 0, don't retry
if retryDelay == 0 {
return err
}
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
// Sleep before retry
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
}
// Truncate file before retry
if err := out.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate file on retry: %w", err)
}
if _, err := out.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek to beginning of file: %w", err)
}
// Second attempt
if err := downloadToFileOnce(ctx, url, out); err != nil {
return fmt.Errorf("download failed after retry: %w", err)
}
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return data, nil
}
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("failed to write response body to file: %w", err)
}
return nil
}
func sleepWithContext(ctx context.Context, duration time.Duration) error {
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -0,0 +1,199 @@
package downloader
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
const (
retryDelay = 100 * time.Millisecond
)
func TestDownloadToFile_Success(t *testing.T) {
// Create a test server that responds successfully
content := "test file content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
}
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
content := "test file content after retry"
var attemptCount atomic.Int32
// Create a test server that fails on first attempt, succeeds on second
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := attemptCount.Add(1)
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should succeed after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
t.Fatalf("expected no error after retry, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
// Verify it took 2 attempts
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should fail after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
t.Fatal("expected error after retry, got nil")
}
// Verify it tried 2 times
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Create a context that will be cancelled during retry delay
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay (during the retry sleep)
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
// Download the file (should fail due to context cancellation during retry)
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
if err == nil {
t.Fatal("expected error due to context cancellation, got nil")
}
// Should have only made 1 attempt (cancelled during retry delay)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_InvalidURL(t *testing.T) {
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
if err == nil {
t.Fatal("expected error for invalid URL, got nil")
}
}
func TestDownloadToFile_InvalidDestination(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test"))
}))
defer server.Close()
// Use an invalid destination path
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
if err == nil {
t.Fatal("expected error for invalid destination, got nil")
}
}
func TestDownloadToFile_NoRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file with retryDelay = 0 (should not retry)
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
t.Fatal("expected error, got nil")
}
// Verify it only made 1 attempt (no retry)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package installer
func UpdaterBinaryNameWithoutExtension() string {
return updaterBinary
}

View File

@@ -0,0 +1,11 @@
package installer
import (
"path/filepath"
"strings"
)
func UpdaterBinaryNameWithoutExtension() string {
ext := filepath.Ext(updaterBinary)
return strings.TrimSuffix(updaterBinary, ext)
}

View File

@@ -0,0 +1,111 @@
// Package installer provides functionality for managing NetBird application
// updates and installations across Windows and macOS. It handles
// the complete update lifecycle including artifact download, cryptographic verification,
// installation execution, process management, and result reporting.
//
// # Architecture
//
// The installer package uses a two-process architecture to enable self-updates:
//
// 1. Service Process: The main NetBird daemon process that initiates updates
// 2. Updater Process: A detached child process that performs the actual installation
//
// This separation is critical because:
// - The service binary cannot update itself while running
// - The installer (EXE/MSI/PKG) will terminate the service during installation
// - The updater process survives service termination and restarts it after installation
// - Results can be communicated back to the service after it restarts
//
// # Update Flow
//
// Service Process (RunInstallation):
//
// 1. Validates target version format (semver)
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
// 3. Downloads installer file from GitHub releases (if applicable)
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
// launching updater)
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
// 6. Launches updater process with detached mode:
// - --temp-dir: Temporary directory path
// - --service-dir: Service installation directory
// - --installer-file: Path to downloaded installer (if applicable)
// - --dry-run: Optional flag to test without actually installing
// 7. Service process continues running (will be terminated by installer later)
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
//
// Updater Process (Setup):
//
// 1. Receives parameters from service via command-line arguments
// 2. Runs installer with appropriate silent/quiet flags:
// - Windows EXE: installer.exe /S
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
// - macOS PKG: installer -pkg installer.pkg -target /
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
// 3. Installer terminates daemon and UI processes
// 4. Installer replaces binaries with new version
// 5. Updater waits for installer to complete
// 6. Updater restarts daemon:
// - Windows: netbird.exe service start
// - macOS/Linux: netbird service start
// 7. Updater restarts UI:
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
// - Linux: Not implemented (UI typically auto-starts)
// 8. Updater writes result.json with success/error status
// 9. Updater process exits
//
// # Result Communication
//
// The ResultHandler (result.go) manages communication between updater and service:
//
// Result Structure:
//
// type Result struct {
// Success bool // true if installation succeeded
// Error string // error message if Success is false
// ExecutedAt time.Time // when installation completed
// }
//
// Result files are automatically cleaned up after being read.
//
// # File Locations
//
// Temporary Directory (platform-specific):
//
// Windows:
// - Path: %ProgramData%\Netbird\tmp-install
// - Example: C:\ProgramData\Netbird\tmp-install
//
// macOS:
// - Path: /var/lib/netbird/tmp-install
// - Requires root permissions
//
// Files created during installation:
//
// tmp-install/
// installer.log
// updater[.exe] # Copy of service binary
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
// result.json # Installation result
// msi.log # MSI verbose log (Windows MSI only)
//
// # Cleanup
//
// CleanUpInstallerFiles() removes temporary files after successful installation:
// - Downloaded installer files (*.exe, *.msi, *.pkg)
// - Updater binary copy
// - Does NOT remove result.json (cleaned by ResultHandler after read)
// - Does NOT remove msi.log (kept for debugging)
//
// # Dry-Run Mode
//
// Dry-run mode allows testing the update process without actually installing:
//
// Enable via environment variable:
//
// export NB_AUTO_UPDATE_DRY_RUN=true
// netbird service install-update 0.29.0
package installer

View File

@@ -0,0 +1,50 @@
//go:build !windows && !darwin
package installer
import (
"context"
"fmt"
)
const (
updaterBinary = "updater"
)
type Installer struct {
tempDir string
}
// New creates an Installer; used by the service process.
func New() *Installer {
return &Installer{}
}
// NewWithDir creates an Installer for the updater process, which receives the temp dir from the service via the command line.
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
func (u *Installer) TempDir() string {
return u.tempDir
}
func (u *Installer) LogFiles() []string {
return []string{}
}
func (u *Installer) CleanUpInstallerFiles() error {
return nil
}
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
return fmt.Errorf("unsupported platform")
}
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
return fmt.Errorf("unsupported platform")
}

View File

@@ -0,0 +1,293 @@
//go:build windows || darwin
package installer
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"github.com/hashicorp/go-multierror"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
type Installer struct {
tempDir string
}
// New creates an Installer with the default temp directory; used by the service process.
func New() *Installer {
return &Installer{
tempDir: defaultTempDir,
}
}
// NewWithDir creates an Installer for the updater process, which receives the temp dir from the service via the command line.
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
// RunInstallation starts the updater process to run the installation
// This is run by the original service process.
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
resultHandler := NewResultHandler(u.tempDir)
defer func() {
if err != nil {
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
log.Errorf("failed to write error result: %v", writeErr)
}
}
}()
if err := validateTargetVersion(targetVersion); err != nil {
return err
}
if err := u.mkTempDir(); err != nil {
return err
}
var installerFile string
// Download files only when not using any third-party store
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
log.Infof("download installer")
var err error
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
if err != nil {
log.Errorf("failed to download installer: %v", err)
return err
}
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
if err != nil {
log.Errorf("failed to create artifact verify: %v", err)
return err
}
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
log.Errorf("artifact verification error: %v", err)
return err
}
}
log.Infof("running installer")
updaterPath, err := u.copyUpdater()
if err != nil {
return err
}
// the directory where the service has been installed
workspace, err := getServiceDir()
if err != nil {
return err
}
args := []string{
"--temp-dir", u.tempDir,
"--service-dir", workspace,
}
if isDryRunEnabled() {
args = append(args, "--dry-run=true")
}
if installerFile != "" {
args = append(args, "--installer-file", installerFile)
}
updateCmd := exec.Command(updaterPath, args...)
log.Infof("starting updater process: %s", updateCmd.String())
// Configure the updater to run in a separate session/process group
// so it survives the parent daemon being stopped
setUpdaterProcAttr(updateCmd)
// Start the updater process asynchronously
if err := updateCmd.Start(); err != nil {
return err
}
pid := updateCmd.Process.Pid
log.Infof("updater started with PID %d", pid)
// Release the process so the OS can fully detach it
if err := updateCmd.Process.Release(); err != nil {
log.Warnf("failed to release updater process: %v", err)
}
return nil
}
// CleanUpInstallerFiles removes installation artifacts from the temp directory:
// - the installer file (pkg, exe, msi)
// - the self-copied updater binary
func (u *Installer) CleanUpInstallerFiles() error {
// Check if tempDir exists
info, err := os.Stat(u.tempDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
var merr *multierror.Error
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
}
entries, err := os.ReadDir(u.tempDir)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
for _, ext := range binaryExtensions {
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
}
break
}
}
}
return merr.ErrorOrNil()
}
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
fileURL := urlWithVersionArch(installerType, targetVersion)
// Clean up temp directory on error
var success bool
defer func() {
if !success {
if err := os.RemoveAll(u.tempDir); err != nil {
log.Errorf("error cleaning up temporary directory: %v", err)
}
}
}()
fileName := path.Base(fileURL)
if fileName == "." || fileName == "/" || fileName == "" {
return "", fmt.Errorf("invalid file URL: %s", fileURL)
}
outputFilePath := filepath.Join(u.tempDir, fileName)
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
return "", err
}
success = true
return outputFilePath, nil
}
func (u *Installer) TempDir() string {
return u.tempDir
}
func (u *Installer) mkTempDir() error {
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
log.Debugf("failed to create tempdir %s: %v", u.tempDir, err)
return err
}
return nil
}
func (u *Installer) copyUpdater() (string, error) {
src, err := getServiceBinary()
if err != nil {
return "", fmt.Errorf("failed to get updater binary: %w", err)
}
dst := filepath.Join(u.tempDir, updaterBinary)
if err := copyFile(src, dst); err != nil {
return "", fmt.Errorf("failed to copy updater binary: %w", err)
}
if err := os.Chmod(dst, 0o755); err != nil {
return "", fmt.Errorf("failed to set permissions: %w", err)
}
return dst, nil
}
func validateTargetVersion(targetVersion string) error {
if targetVersion == "" {
return fmt.Errorf("target version cannot be empty")
}
_, err := goversion.NewVersion(targetVersion)
if err != nil {
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
}
return nil
}
func copyFile(src, dst string) error {
log.Infof("copying %s to %s", src, dst)
in, err := os.Open(src)
if err != nil {
return fmt.Errorf("open source: %w", err)
}
defer func() {
if err := in.Close(); err != nil {
log.Warnf("failed to close source file: %v", err)
}
}()
out, err := os.Create(dst)
if err != nil {
return fmt.Errorf("create destination: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Warnf("failed to close destination file: %v", err)
}
}()
if _, err := io.Copy(out, in); err != nil {
return fmt.Errorf("copy: %w", err)
}
return nil
}
func getServiceDir() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", err
}
return filepath.Dir(exePath), nil
}
func getServiceBinary() (string, error) {
return os.Executable()
}
func isDryRunEnabled() bool {
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
}

View File

@@ -0,0 +1,11 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -0,0 +1,12 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, msiLogFile),
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -0,0 +1,238 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const (
daemonName = "netbird"
updaterBinary = "updater"
uiBinary = "/Applications/NetBird.app"
defaultTempDir = "/var/lib/netbird/tmp-install"
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
var (
binaryExtensions = []string{"pkg"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
// skip service restart if dry-run mode is enabled
if dryRun {
return
}
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(); err != nil {
log.Errorf("failed to start UI: %v", err)
}
}()
if dryRun {
log.Infof("dry-run mode enabled, skipping actual installation")
time.Sleep(7 * time.Second)
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
switch TypeOfInstaller(ctx) {
case TypePKG:
resultErr = u.installPkgFile(ctx, installerFile)
case TypeHomebrew:
resultErr = u.updateHomeBrew(ctx)
}
return resultErr
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser() error {
log.Infof("starting netbird-ui: %s", uiBinary)
// Get the current console user
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("failed to get console user: %w", err)
}
username := strings.TrimSpace(string(output))
if username == "" || username == "root" {
return fmt.Errorf("no active user session found")
}
log.Infof("starting UI for user: %s", username)
// Get user's UID
userInfo, err := user.Lookup(username)
if err != nil {
return fmt.Errorf("failed to lookup user %s: %w", username, err)
}
// Start the UI process as the console user using launchctl
// This ensures the app runs in the user's context with proper GUI access
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
log.Infof("launchCmd: %s", launchCmd.String())
// Set the user's home directory for proper macOS app behavior
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
if err := launchCmd.Start(); err != nil {
return fmt.Errorf("failed to start UI process: %w", err)
}
// Release the process so it can run independently
if err := launchCmd.Process.Release(); err != nil {
log.Warnf("failed to release UI process: %v", err)
}
log.Infof("netbird-ui started successfully for user %s", username)
return nil
}
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
log.Infof("installing pkg file: %s", path)
// Kill any existing UI processes before installation
// This ensures the postinstall script's "open $APP" will start the new version
u.killUI()
volume := "/"
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
if err := cmd.Start(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if err := cmd.Wait(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("pkg file installed successfully")
return nil
}
func (u *Installer) updateHomeBrew(ctx context.Context) error {
log.Infof("updating homebrew")
// Kill any existing UI processes before upgrade
// This ensures the new version will be started after upgrade
u.killUI()
// Homebrew must be run as a non-root user
// To find out which user installed NetBird via Homebrew, we check the owner of the brew tap directory
// Check both Apple Silicon and Intel Mac paths
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath := "/opt/homebrew/bin/brew"
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
// Try Intel Mac path
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath = "/usr/local/bin/brew"
}
fileInfo, err := os.Stat(brewTapPath)
if err != nil {
return fmt.Errorf("error getting homebrew installation path info: %w", err)
}
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
}
// Get username from UID
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
if err != nil {
return fmt.Errorf("error looking up brew installer user: %w", err)
}
userName := brewUser.Username
// Get user HOME, required for brew to run correctly
// https://github.com/Homebrew/brew/issues/15833
homeDir := brewUser.HomeDir
// Check if netbird-ui is installed (must run as the brew user, not root)
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
uiInstalled := checkUICmd.Run() == nil
// Homebrew does not support installing specific versions
// Thus it will always update to latest and ignore targetVersion
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
if uiInstalled {
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
}
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
cmd.Env = append(os.Environ(), "HOME="+homeDir)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
}
log.Infof("homebrew updated successfully")
return nil
}
func (u *Installer) killUI() {
log.Infof("killing existing netbird-ui processes")
cmd := exec.Command("pkill", "-x", "netbird-ui")
if output, err := cmd.CombinedOutput(); err != nil {
// pkill returns exit code 1 if no processes matched, which is fine
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
} else {
log.Infof("netbird-ui processes killed")
}
}
func urlWithVersionArch(_ Type, version string) string {
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}


@@ -0,0 +1,213 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
daemonName = "netbird.exe"
uiName = "netbird-ui.exe"
updaterBinary = "updater.exe"
msiLogFile = "msi.log"
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
var (
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
// extensions of downloaded installer artifacts removed during cleanup
binaryExtensions = []string{"msi", "exe"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(daemonFolder); err != nil {
log.Errorf("failed to start UI: %v", err)
}
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
}()
if dryRun {
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
installerType, err := typeByFileExtension(installerFile)
if err != nil {
log.Debugf("%v", err)
resultErr = err
return
}
var cmd *exec.Cmd
switch installerType {
case TypeExe:
log.Infof("run exe installer: %s", installerFile)
cmd = exec.CommandContext(ctx, installerFile, "/S")
default:
installerDir := filepath.Dir(installerFile)
logPath := filepath.Join(installerDir, msiLogFile)
log.Infof("run msi installer: %s", installerFile)
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
}
cmd.Dir = filepath.Dir(installerFile)
if resultErr = cmd.Start(); resultErr != nil {
log.Errorf("error starting installer: %v", resultErr)
return
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if resultErr = cmd.Wait(); resultErr != nil {
log.Errorf("installer process finished with error: %v", resultErr)
return
}
return nil
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser(daemonFolder string) error {
uiPath := filepath.Join(daemonFolder, uiName)
log.Infof("starting netbird-ui: %s", uiPath)
// Get the active console session ID
sessionID := windows.WTSGetActiveConsoleSessionId()
if sessionID == 0xFFFFFFFF {
return fmt.Errorf("no active user session found")
}
// Get the user token for that session
var userToken windows.Token
err := windows.WTSQueryUserToken(sessionID, &userToken)
if err != nil {
return fmt.Errorf("failed to query user token: %w", err)
}
defer func() {
if err := userToken.Close(); err != nil {
log.Warnf("failed to close user token: %v", err)
}
}()
// Duplicate the token to a primary token
var primaryToken windows.Token
err = windows.DuplicateTokenEx(
userToken,
windows.MAXIMUM_ALLOWED,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return fmt.Errorf("failed to duplicate token: %w", err)
}
defer func() {
if err := primaryToken.Close(); err != nil {
log.Warnf("failed to close token: %v", err)
}
}()
// Prepare startup info
var si windows.StartupInfo
si.Cb = uint32(unsafe.Sizeof(si))
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
var pi windows.ProcessInformation
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
if err != nil {
return fmt.Errorf("failed to convert path to UTF16: %w", err)
}
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
err = windows.CreateProcessAsUser(
primaryToken,
nil,
cmdLine,
nil,
nil,
false,
creationFlags,
nil,
nil,
&si,
&pi,
)
if err != nil {
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
}
// Close handles
if err := windows.CloseHandle(pi.Process); err != nil {
log.Warnf("failed to close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Warnf("failed to close thread handle: %v", err)
}
log.Infof("netbird-ui started successfully in session %d", sessionID)
return nil
}
func urlWithVersionArch(it Type, version string) string {
var url string
if it == TypeExe {
url = exeDownloadURL
} else {
url = msiDownloadURL
}
url = strings.ReplaceAll(url, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}


@@ -0,0 +1,5 @@
package installer
const (
LogFile = "installer.log"
)


@@ -0,0 +1,15 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run in a new session,
// making it independent of the parent daemon process. This ensures the updater
// survives when the daemon is stopped during the pkg installation.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
}


@@ -0,0 +1,14 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run detached from the parent,
// making it independent of the parent daemon process.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
}
}


@@ -0,0 +1,7 @@
//go:build devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
)


@@ -0,0 +1,7 @@
//go:build !devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
)


@@ -0,0 +1,230 @@
package installer
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
resultFile = "result.json"
)
type Result struct {
Success bool
Error string
ExecutedAt time.Time
}
// ResultHandler handles reading and writing update results
type ResultHandler struct {
resultFile string
}
// NewResultHandler creates a new ResultHandler rooted at the given directory
// The result file will be created as "result.json" in the specified directory
func NewResultHandler(installerDir string) *ResultHandler {
// Create the directory if it doesn't exist; MkdirAll is a no-op when it already does
_ = os.MkdirAll(installerDir, 0o700)
rh := &ResultHandler{
resultFile: filepath.Join(installerDir, resultFile),
}
return rh
}
func (rh *ResultHandler) GetErrorResultReason() string {
result, err := rh.tryReadResult()
if err == nil && !result.Success {
return result.Error
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return ""
}
func (rh *ResultHandler) WriteSuccess() error {
result := Result{
Success: true,
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) WriteErr(errReason error) error {
result := Result{
Success: false,
Error: errReason.Error(),
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
log.Infof("start watching result: %s", rh.resultFile)
// Check if file already exists (updater finished before we started watching)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
dir := filepath.Dir(rh.resultFile)
if err := rh.waitForDirectory(ctx, dir); err != nil {
return Result{}, err
}
return rh.watchForResultFile(ctx, dir)
}
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
ticker := time.NewTicker(300 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return nil
}
}
}
}
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Error(err)
return Result{}, err
}
defer func() {
if err := watcher.Close(); err != nil {
log.Warnf("failed to close watcher: %v", err)
}
}()
if err := watcher.Add(dir); err != nil {
return Result{}, fmt.Errorf("failed to watch directory: %w", err)
}
// Check again after setting up watcher to avoid race condition
// (file could have been created between initial check and watcher setup)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
for {
select {
case <-ctx.Done():
return Result{}, ctx.Err()
case event, ok := <-watcher.Events:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
if result, done := rh.handleWatchEvent(event); done {
return result, nil
}
case err, ok := <-watcher.Errors:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
return Result{}, fmt.Errorf("watcher error: %w", err)
}
}
}
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
if event.Name != rh.resultFile {
return Result{}, false
}
if event.Has(fsnotify.Create) {
result, err := rh.tryReadResult()
if err != nil {
log.Debugf("error while reading result: %v", err)
return result, true
}
log.Infof("installer result: %v", result)
return result, true
}
return Result{}, false
}
// write writes the update result to a file for the UI to read
func (rh *ResultHandler) write(result Result) error {
log.Infof("write out installer result to: %s", rh.resultFile)
// Ensure directory exists
dir := filepath.Dir(rh.resultFile)
if err := os.MkdirAll(dir, 0o755); err != nil {
log.Errorf("failed to create directory %s: %v", dir, err)
return err
}
data, err := json.Marshal(result)
if err != nil {
return err
}
// Write to a temporary file first, then rename for atomic operation
tmpPath := rh.resultFile + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
log.Errorf("failed to create temp file: %s", err)
return err
}
// Atomic rename
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
log.Warnf("failed to remove temp result file: %v", cleanupErr)
}
return err
}
return nil
}
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil && !os.IsNotExist(err) {
return err
}
log.Debugf("deleted installer result file: %s", rh.resultFile)
return nil
}
// tryReadResult attempts to read and validate the result file
func (rh *ResultHandler) tryReadResult() (Result, error) {
data, err := os.ReadFile(rh.resultFile)
if err != nil {
return Result{}, err
}
var result Result
if err := json.Unmarshal(data, &result); err != nil {
return Result{}, fmt.Errorf("invalid result format: %w", err)
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return result, nil
}


@@ -0,0 +1,14 @@
package installer
type Type struct {
name string
downloadable bool
}
func (t Type) String() string {
return t.name
}
func (t Type) Downloadable() bool {
return t.downloadable
}


@@ -0,0 +1,22 @@
package installer
import (
"context"
"os/exec"
)
var (
TypeHomebrew = Type{name: "Homebrew", downloadable: false}
TypePKG = Type{name: "pkg", downloadable: true}
)
func TypeOfInstaller(ctx context.Context) Type {
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
_, err := cmd.Output()
if err != nil && cmd.ProcessState != nil && cmd.ProcessState.ExitCode() == 1 {
// Not installed using pkg file, thus installed using Homebrew
return TypeHomebrew
}
return TypePKG
}


@@ -0,0 +1,51 @@
package installer
import (
"context"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
const (
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
)
var (
TypeExe = Type{name: "EXE", downloadable: true}
TypeMSI = Type{name: "MSI", downloadable: true}
)
func TypeOfInstaller(_ context.Context) Type {
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
for _, path := range paths {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
continue
}
if err := k.Close(); err != nil {
log.Warnf("Error closing registry key: %v", err)
}
return TypeExe
}
log.Debug("No registry entry found for Netbird, assuming MSI installation")
return TypeMSI
}
func typeByFileExtension(filePath string) (Type, error) {
switch {
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
return TypeExe, nil
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
return TypeMSI, nil
default:
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
}
}
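The suffix matching in `typeByFileExtension` can be sketched standalone; a minimal version that returns plain strings instead of the package's `Type` values, for illustration only:

```go
package main

import (
	"fmt"
	"strings"
)

// typeByExt is a simplified stand-in for typeByFileExtension: it keys the
// installer kind off a case-insensitive file suffix.
func typeByExt(path string) (string, error) {
	switch {
	case strings.HasSuffix(strings.ToLower(path), ".exe"):
		return "EXE", nil
	case strings.HasSuffix(strings.ToLower(path), ".msi"):
		return "MSI", nil
	default:
		return "", fmt.Errorf("unsupported installer type for file: %s", path)
	}
}

func main() {
	for _, p := range []string{"netbird_installer.MSI", "setup.exe", "archive.zip"} {
		kind, err := typeByExt(p)
		fmt.Println(p, kind, err)
	}
}
```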


@@ -0,0 +1,374 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
latestVersion = "latest"
// the development version is ignored by auto-update
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}
func (u UpdateState) Name() string {
return "autoUpdate"
}
type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
currentVersion string
update UpdateInterface
wg sync.WaitGroup
cancel context.CancelFunc
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protects the update and expectedVersion fields
updateMutex sync.Mutex
triggerUpdateFn func(context.Context, string) error
}
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Homebrew installation")
return nil, fmt.Errorf("auto-update not supported on Homebrew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager, nil
}
// CheckUpdateSuccess checks if the update was successful and sends a notification.
// It works without starting the update manager.
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
reason := m.lastResultErrReason()
if reason != "" {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %s", reason),
nil,
)
}
updateState, err := m.loadAndDeleteUpdateState(ctx)
if err != nil {
if errors.Is(err, errNoUpdateState) {
return
}
log.Errorf("failed to load update state: %v", err)
return
}
log.Debugf("auto-update state loaded, %v", *updateState)
if updateState.TargetVersion == m.currentVersion {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
nil,
)
return
}
}
func (m *Manager) Start(ctx context.Context) {
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
m.update.SetDaemonVersion(m.currentVersion)
m.update.SetOnUpdateListener(func() {
select {
case m.updateChannel <- struct{}{}:
default:
}
})
go m.update.StartFetcher()
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
go m.updateLoop(ctx)
}
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if expectedVersion == "" {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
return
}
if expectedVersion == latestVersion {
m.updateToLatestVersion = true
m.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("error parsing version: %v", err)
return
}
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
return
}
m.expectedVersion = expectedSemVer
m.updateToLatestVersion = false
}
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) Stop() {
if m.cancel == nil {
return
}
m.cancel()
m.updateMutex.Lock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
m.updateMutex.Unlock()
m.wg.Wait()
}
func (m *Manager) onContextCancel() {
if m.cancel == nil {
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
m.onContextCancel()
return
case <-m.mgmUpdateChan:
case <-m.updateChannel:
log.Infof("fetched new version info")
}
m.handleUpdate(ctx)
}
}
func (m *Manager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion) {
return
}
m.lastTrigger = time.Now()
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": updateVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: updateVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
if err = m.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
}
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
// Returns nil if no state exists.
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
stateType := &UpdateState{}
m.stateManager.RegisterState(stateType)
if err := m.stateManager.LoadState(stateType); err != nil {
return nil, fmt.Errorf("load state: %w", err)
}
state := m.stateManager.GetState(stateType)
if state == nil {
return nil, errNoUpdateState
}
updateState, ok := state.(*UpdateState)
if !ok {
return nil, fmt.Errorf("failed to cast state to UpdateState")
}
if err := m.stateManager.DeleteState(updateState); err != nil {
return nil, fmt.Errorf("delete state: %w", err)
}
if err := m.stateManager.PersistState(ctx); err != nil {
return nil, fmt.Errorf("persist state: %w", err)
}
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil {
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
return false
}
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
return true
}
func (m *Manager) lastResultErrReason() string {
inst := installer.New()
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}


@@ -0,0 +1,214 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (m versionUpdateMock) StopWatch() {}
func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
m.onUpdate = updateFn
}
func (m versionUpdateMock) LatestVersion() *v.Version {
return m.latestVersion
}
func (m versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}


@@ -0,0 +1,39 @@
//go:build !windows && !darwin
package updatemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager is a no-op stub for unsupported platforms
type Manager struct{}
// NewManager returns a no-op manager for unsupported platforms
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
return nil, fmt.Errorf("update manager is not supported on this platform")
}
// CheckUpdateSuccess is a no-op on unsupported platforms
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
// no-op
}
// Start is a no-op on unsupported platforms
func (m *Manager) Start(ctx context.Context) {
// no-op
}
// SetVersion is a no-op on unsupported platforms
func (m *Manager) SetVersion(expectedVersion string) {
// no-op
}
// Stop is a no-op on unsupported platforms
func (m *Manager) Stop() {
// no-op
}

View File

@@ -0,0 +1,302 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/blake2s"
)
const (
tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
tagArtifactPublic = "ARTIFACT PUBLIC KEY"
maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
)
// ArtifactHash wraps the hash.Hash used to fingerprint artifacts
type ArtifactHash struct {
hash.Hash
}
// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
func NewArtifactHash() *ArtifactHash {
h, err := blake2s.New256(nil)
if err != nil {
panic(err) // should never happen with a nil key
}
return &ArtifactHash{Hash: h}
}
func (ah *ArtifactHash) Write(b []byte) (int, error) {
return ah.Hash.Write(b)
}
// ArtifactKey is a signing key used to sign artifacts
type ArtifactKey struct {
PrivateKey
}
func (k ArtifactKey) String() string {
return fmt.Sprintf(
"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
// Verify root key is still valid
if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
}
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
ak := &ArtifactKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(ak.PrivateKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
// Sign the public key with the root key
signature, err := SignArtifactKey(*rootKey, pubPEM)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
}
return ak, privPEM, pubPEM, signature, nil
}
func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
if err != nil {
return ArtifactKey{}, fmt.Errorf("failed to parse artifact key: %w", err)
}
return ArtifactKey{pk}, nil
}
func ParseArtifactPubKey(data []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(data, tagArtifactPublic)
return pk, err
}
func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
if len(keys) == 0 {
return nil, nil, errors.New("no keys to bundle")
}
// Create bundle by concatenating PEM-encoded keys
var pubBundle []byte
for _, pk := range keys {
// Marshal PublicKey struct to JSON
pubJSON, err := json.Marshal(pk)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
pubBundle = append(pubBundle, pubPEM...)
}
// Sign the entire bundle with the root key
signature, err := SignArtifactKey(*rootKey, pubBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
}
return pubBundle, signature, nil
}
func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
// Reconstruct the signed message: artifact_key_data || timestamp
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("failed to verify signature of artifact keys")
}
pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
if err != nil {
log.Debugf("failed to parse public keys: %s", err)
return nil, err
}
validKeys := make([]PublicKey, 0, len(pubKeys))
for _, pubKey := range pubKeys {
// Filter out expired keys
if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
log.Debugf("key %s expired at %v (current time %v)",
pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
continue
}
if revocationList != nil {
if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
log.Debugf("key %s is revoked as of %v (created %v)",
pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
continue
}
}
validKeys = append(validKeys, pubKey)
}
if len(validKeys) == 0 {
log.Debugf("no valid public keys found for artifact keys")
return nil, fmt.Errorf("all %d artifact keys are expired or revoked", len(pubKeys))
}
return validKeys, nil
}
func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
// Validate signature timestamp
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("failed to verify signature of artifact: %s", err)
return err
}
if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
return fmt.Errorf("artifact signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return fmt.Errorf("failed to hash artifact: %w", err)
}
hash := h.Sum(nil)
// Reconstruct the signed message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
// Find the matching key and verify
for _, keyInfo := range artifactPubKeys {
if keyInfo.Metadata.ID == signature.KeyID {
// Check key expiration
if !keyInfo.Metadata.ExpiresAt.IsZero() &&
signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
return fmt.Errorf("signing key %s expired at %v, signature from %v",
signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
}
if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
log.Debugf("artifact verified successfully with key: %s", signature.KeyID)
return nil
}
return fmt.Errorf("signature verification failed for key %s", signature.KeyID)
}
}
return fmt.Errorf("no signing key found with ID %s", signature.KeyID)
}
func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
if len(data) == 0 { // Reject empty artifacts before hashing
return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return nil, fmt.Errorf("failed to write artifact hash: %w", err)
}
timestamp := time.Now().UTC()
if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
}
hash := h.Sum(nil)
// Create message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(artifactKey.Key, msg)
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: artifactKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
return json.Marshal(bundle)
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,6 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -0,0 +1,6 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -0,0 +1,174 @@
// Package reposign implements a cryptographic signing and verification system
// for NetBird software update artifacts. It provides a hierarchical key
// management system with support for key rotation, revocation, and secure
// artifact distribution.
//
// # Architecture
//
// The package uses a two-tier key hierarchy:
//
// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
// in the client binary and establish the root of trust. Root keys should
// be kept offline and highly secured.
//
// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
// packages, etc.). These are rotated regularly and can be revoked if
// compromised. Artifact keys are signed by root keys and distributed via
// a public repository.
//
// This separation allows for operational flexibility: artifact keys can be
// rotated frequently without requiring client updates, while root keys remain
// stable and embedded in the software.
//
// # Cryptographic Primitives
//
// The package uses strong, modern cryptographic algorithms:
// - Ed25519: fast, secure digital signatures designed for constant-time implementation
// - BLAKE2s-256: Fast cryptographic hash for artifacts
// - SHA-256: Key ID generation
// - JSON: Structured key and signature serialization
// - PEM: Standard key encoding format
//
// # Security Features
//
// Timestamp Binding:
// - All signatures include cryptographically-bound timestamps
// - Prevents replay attacks and enforces signature freshness
// - Clock skew tolerance: 5 minutes
//
// Key Expiration:
// - All keys have expiration times
// - Expired keys are automatically rejected
// - Signing with an expired key fails immediately
//
// Key Revocation:
// - Compromised keys can be revoked via a signed revocation list
// - Revocation list is checked during artifact validation
// - Revoked keys are filtered out before artifact verification
//
// # File Structure
//
// The package expects the following file layout in the key repository:
//
// signrepo/
// artifact-key-pub.pem # Bundle of artifact public keys
// artifact-key-pub.pem.sig # Root signature of the bundle
// revocation-list.json # List of revoked key IDs
// revocation-list.json.sig # Root signature of revocation list
//
// And in the artifacts repository:
//
// releases/
// v0.28.0/
// netbird-linux-amd64
// netbird-linux-amd64.sig # Artifact signature
// netbird-darwin-amd64
// netbird-darwin-amd64.sig
// ...
//
// # Embedded Root Keys
//
// Root public keys are embedded in the client binary at compile time:
// - Production keys: certs/ directory
// - Development keys: certsdev/ directory
//
// The build tag determines which keys are embedded:
// - Production builds: //go:build !devartifactsign
// - Development builds: //go:build devartifactsign
//
// This ensures that development artifacts cannot be verified using production
// keys and vice versa.
//
// # Key Rotation Strategies
//
// Root Key Rotation:
//
// Root keys can be rotated without breaking existing clients by leveraging
// the multi-key verification system. The loadEmbeddedPublicKeys function
// reads ALL files from the certs/ directory and accepts signatures from ANY
// of the embedded root keys.
//
// To rotate root keys:
//
// 1. Generate a new root key pair:
// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
//
// 2. Add the new public key to the certs/ directory as a new file:
// certs/
// root-pub-2024.pem # Old key (keep this!)
// root-pub-2025.pem # New key (add this)
//
// 3. Build new client versions with both keys embedded. The verification
// will accept signatures from either key.
//
// 4. Start signing new artifact keys with the new root key. Old clients
// with only the old root key will reject these, but new clients with
// both keys will accept them.
//
// Each file in certs/ can contain a single key or a bundle of keys (multiple
// PEM blocks). The system will parse all keys from all files and use them
// for verification. This provides maximum flexibility for key management.
//
// Important: Never remove all old root keys at once. Always maintain at least
// one overlapping key between releases to ensure smooth transitions.
//
// Artifact Key Rotation:
//
// Artifact keys should be rotated regularly (e.g., every 90 days) using the
// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
// keys to be bundled together in a single signed package, and ValidateArtifact
// will accept signatures from ANY key in the bundle.
//
// To rotate artifact keys smoothly:
//
// 1. Generate a new artifact key while keeping the old one:
// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
// // Keep oldPubPEM and oldKey available
//
// 2. Create a bundle containing both old and new public keys
//
// 3. Upload the bundle and its signature to the key repository:
// signrepo/artifact-key-pub.pem # Contains both keys
// signrepo/artifact-key-pub.pem.sig # Root signature
//
// 4. Start signing new releases with the NEW key, but keep the bundle
// unchanged. Clients will download the bundle (containing both keys)
// and accept signatures from either key.
//
// Key bundle validation workflow:
// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
// 3. ValidateArtifactKeys parses all public keys from the bundle
// 4. ValidateArtifactKeys filters out expired or revoked keys
// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
//
// This multi-key acceptance model enables overlapping validity periods and
// smooth transitions without client update requirements.
//
// # Best Practices
//
// Root Key Management:
// - Generate root keys offline on an air-gapped machine
// - Store root private keys in hardware security modules (HSM) if possible
// - Use separate root keys for production and development
// - Rotate root keys infrequently (e.g., every 5-10 years)
// - Plan for root key rotation: embed multiple root public keys
//
// Artifact Key Management:
// - Rotate artifact keys regularly (e.g., every 90 days)
// - Use separate artifact keys for different release channels if needed
// - Revoke keys immediately upon suspected compromise
// - Bundle multiple artifact keys to enable smooth rotation
//
// Signing Process:
// - Sign artifacts in a secure CI/CD environment
// - Never commit private keys to version control
// - Use environment variables or secret management for keys
// - Verify signatures immediately after signing
//
// Distribution:
// - Serve keys and revocation lists from a reliable CDN
// - Use HTTPS for all key and artifact downloads
// - Monitor download failures and signature verification failures
// - Keep revocation list up to date
package reposign

View File

@@ -0,0 +1,10 @@
//go:build devartifactsign

package reposign
import "embed"
//go:embed certsdev
var embeddedCerts embed.FS
const embeddedCertsDir = "certsdev"

View File

@@ -0,0 +1,10 @@
//go:build !devartifactsign

package reposign
import "embed"
//go:embed certs
var embeddedCerts embed.FS
const embeddedCertsDir = "certs"

View File

@@ -0,0 +1,171 @@
package reposign
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"time"
)
const (
maxClockSkew = 5 * time.Minute
)
// KeyID is a unique identifier for a key (first 8 bytes of SHA-256 of the public key)
type KeyID [8]byte
// computeKeyID generates a unique ID from a public key
func computeKeyID(pub ed25519.PublicKey) KeyID {
h := sha256.Sum256(pub)
var id KeyID
copy(id[:], h[:8])
return id
}
// MarshalJSON implements json.Marshaler for KeyID
func (k KeyID) MarshalJSON() ([]byte, error) {
return json.Marshal(k.String())
}
// UnmarshalJSON implements json.Unmarshaler for KeyID
func (k *KeyID) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
parsed, err := ParseKeyID(s)
if err != nil {
return err
}
*k = parsed
return nil
}
// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
func ParseKeyID(s string) (KeyID, error) {
var id KeyID
if len(s) != 16 {
return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
}
b, err := hex.DecodeString(s)
if err != nil {
return id, fmt.Errorf("failed to decode KeyID: %w", err)
}
copy(id[:], b)
return id, nil
}
func (k KeyID) String() string {
return fmt.Sprintf("%x", k[:])
}
// KeyMetadata contains versioning and lifecycle information for a key
type KeyMetadata struct {
ID KeyID `json:"id"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
}
// PublicKey wraps a public key with its metadata
type PublicKey struct {
Key ed25519.PublicKey
Metadata KeyMetadata
}
func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
var keys []PublicKey
for len(bundle) > 0 {
keyInfo, rest, err := parsePublicKey(bundle, typeTag)
if err != nil {
return nil, err
}
keys = append(keys, keyInfo)
bundle = rest
}
if len(keys) == 0 {
return nil, errors.New("no keys found in bundle")
}
return keys, nil
}
func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
b, rest := pem.Decode(data)
if b == nil {
return PublicKey{}, nil, errors.New("failed to decode PEM data")
}
if b.Type != typeTag {
return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pub PublicKey
if err := json.Unmarshal(b.Bytes, &pub); err != nil {
return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
}
// Validate key length
if len(pub.Key) != ed25519.PublicKeySize {
return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
ed25519.PublicKeySize, len(pub.Key))
}
// Always recompute ID to ensure integrity
pub.Metadata.ID = computeKeyID(pub.Key)
return pub, rest, nil
}
type PrivateKey struct {
Key ed25519.PrivateKey
Metadata KeyMetadata
}
func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
b, rest := pem.Decode(data)
if b == nil {
return PrivateKey{}, errors.New("failed to decode PEM data")
}
if len(rest) > 0 {
return PrivateKey{}, errors.New("trailing PEM data")
}
if b.Type != typeTag {
return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pk PrivateKey
if err := json.Unmarshal(b.Bytes, &pk); err != nil {
return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
}
// Validate key length
if len(pk.Key) != ed25519.PrivateKeySize {
return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
ed25519.PrivateKeySize, len(pk.Key))
}
return pk, nil
}
func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
// Verify with root keys
var rootKeys []ed25519.PublicKey
for _, r := range publicRootKeys {
rootKeys = append(rootKeys, r.Key)
}
for _, k := range rootKeys {
if ed25519.Verify(k, msg, sig) {
return true
}
}
return false
}

View File

@@ -0,0 +1,636 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test KeyID functions
func TestComputeKeyID(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
// Verify it's the first 8 bytes of SHA-256
h := sha256.Sum256(pub)
expectedID := KeyID{}
copy(expectedID[:], h[:8])
assert.Equal(t, expectedID, keyID)
}
func TestComputeKeyID_Deterministic(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Computing KeyID multiple times should give the same result
keyID1 := computeKeyID(pub)
keyID2 := computeKeyID(pub)
assert.Equal(t, keyID1, keyID2)
}
func TestComputeKeyID_DifferentKeys(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
// Different keys should produce different IDs
assert.NotEqual(t, keyID1, keyID2)
}
func TestParseKeyID_Valid(t *testing.T) {
hexStr := "0123456789abcdef"
keyID, err := ParseKeyID(hexStr)
require.NoError(t, err)
expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
assert.Equal(t, expected, keyID)
}
func TestParseKeyID_InvalidLength(t *testing.T) {
tests := []struct {
name string
input string
}{
{"too short", "01234567"},
{"too long", "0123456789abcdef00"},
{"empty", ""},
{"odd length", "0123456789abcde"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseKeyID(tt.input)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid KeyID length")
})
}
}
func TestParseKeyID_InvalidHex(t *testing.T) {
invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
_, err := ParseKeyID(invalidHex)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode KeyID")
}
func TestKeyID_String(t *testing.T) {
keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
str := keyID.String()
assert.Equal(t, "0123456789abcdef", str)
}
func TestKeyID_RoundTrip(t *testing.T) {
original := "fedcba9876543210"
keyID, err := ParseKeyID(original)
require.NoError(t, err)
result := keyID.String()
assert.Equal(t, original, result)
}
func TestKeyID_ZeroValue(t *testing.T) {
keyID := KeyID{}
str := keyID.String()
assert.Equal(t, "0000000000000000", str)
}
// Test KeyMetadata
func TestKeyMetadata_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, metadata.ID, decoded.ID)
assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
}
func TestKeyMetadata_NoExpiration(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Time{}, // Zero value = no expiration
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.True(t, decoded.ExpiresAt.IsZero())
}
// Test PublicKey
func TestPublicKey_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
var decoded PublicKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, pubKey.Key, decoded.Key)
assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePublicKey
func TestParsePublicKey_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
}
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
// Marshal to JSON
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode to PEM
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse it back
parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
assert.Empty(t, rest)
assert.Equal(t, pub, parsed.Key)
assert.Equal(t, metadata.ID, parsed.Metadata.ID)
}
func TestParsePublicKey_InvalidPEM(t *testing.T) {
invalidPEM := []byte("not a PEM")
_, _, err := parsePublicKey(invalidPEM, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePublicKey_WrongType(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode with wrong type
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePublicKey_InvalidJSON(t *testing.T) {
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: []byte("invalid json"),
})
_, _, err := parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParsePublicKey_InvalidKeySize(t *testing.T) {
// Create a public key with wrong size
pubKey := PublicKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
}
func TestParsePublicKey_IDRecomputation(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Create a public key with WRONG ID
wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: wrongID,
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse should recompute the correct ID
parsed, _, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
correctID := computeKeyID(pub)
assert.Equal(t, correctID, parsed.Metadata.ID)
assert.NotEqual(t, wrongID, parsed.Metadata.ID)
}
// Test parsePublicKeyBundle
func TestParsePublicKeyBundle_Single(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Equal(t, pub, keys[0].Key)
}
func TestParsePublicKeyBundle_Multiple(t *testing.T) {
var bundle []byte
// Create 3 keys
for i := 0; i < 3; i++ {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
bundle = append(bundle, pemData...)
}
keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 3)
}
func TestParsePublicKeyBundle_Empty(t *testing.T) {
_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no keys found")
}
func TestParsePublicKeyBundle_Invalid(t *testing.T) {
_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
assert.Error(t, err)
}
// Test PrivateKey
func TestPrivateKey_JSONMarshaling(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
var decoded PrivateKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, privKey.Key, decoded.Key)
assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePrivateKey
func TestParsePrivateKey_Valid(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
parsed, err := parsePrivateKey(pemData, tagRootPrivate)
require.NoError(t, err)
assert.Equal(t, priv, parsed.Key)
}
func TestParsePrivateKey_InvalidPEM(t *testing.T) {
_, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePrivateKey_TrailingData(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
// Add trailing data
pemData = append(pemData, []byte("extra data")...)
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "trailing PEM data")
}
func TestParsePrivateKey_WrongType(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
privKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
// Test verifyAny
func TestVerifyAny_ValidSignature(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_InvalidSignature(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
invalidSignature := make([]byte, ed25519.SignatureSize)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, invalidSignature)
assert.False(t, result)
}
func TestVerifyAny_MultipleKeys(t *testing.T) {
// Create 3 key pairs
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub3, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
{Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
{Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_NoMatchingKey(t *testing.T) {
_, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
// Only include pub2, not pub1
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
}
result := verifyAny(rootKeys, message, signature)
assert.False(t, result)
}
func TestVerifyAny_EmptyKeys(t *testing.T) {
message := []byte("test message")
signature := make([]byte, ed25519.SignatureSize)
result := verifyAny([]PublicKey{}, message, signature)
assert.False(t, result)
}
func TestVerifyAny_TamperedMessage(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
}
// Verify with different message
tamperedMessage := []byte("different message")
result := verifyAny(rootKeys, tamperedMessage, signature)
assert.False(t, result)
}


@@ -0,0 +1,229 @@
package reposign
import (
"crypto/ed25519"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour
defaultRevocationListExpiration = 365 * 24 * time.Hour
)
type RevocationList struct {
Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time
LastUpdated time.Time `json:"last_updated"` // When the list was last modified
ExpiresAt time.Time `json:"expires_at"` // When the list expires
}
func (rl RevocationList) MarshalJSON() ([]byte, error) {
// Convert map[KeyID]time.Time to map[string]time.Time
strMap := make(map[string]time.Time, len(rl.Revoked))
for k, v := range rl.Revoked {
strMap[k.String()] = v
}
return json.Marshal(map[string]interface{}{
"revoked": strMap,
"last_updated": rl.LastUpdated,
"expires_at": rl.ExpiresAt,
})
}
func (rl *RevocationList) UnmarshalJSON(data []byte) error {
var temp struct {
Revoked map[string]time.Time `json:"revoked"`
LastUpdated time.Time `json:"last_updated"`
ExpiresAt time.Time `json:"expires_at"`
Version int `json:"version"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Convert map[string]time.Time back to map[KeyID]time.Time
rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked))
for k, v := range temp.Revoked {
kid, err := ParseKeyID(k)
if err != nil {
return fmt.Errorf("failed to parse KeyID %q: %w", k, err)
}
rl.Revoked[kid] = v
}
rl.LastUpdated = temp.LastUpdated
rl.ExpiresAt = temp.ExpiresAt
return nil
}
func ParseRevocationList(data []byte) (*RevocationList, error) {
var rl RevocationList
if err := json.Unmarshal(data, &rl); err != nil {
return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err)
}
// Initialize the map if it's nil (in case of empty JSON object)
if rl.Revoked == nil {
rl.Revoked = make(map[KeyID]time.Time)
}
if rl.LastUpdated.IsZero() {
return nil, fmt.Errorf("revocation list missing last_updated timestamp")
}
if rl.ExpiresAt.IsZero() {
return nil, fmt.Errorf("revocation list missing expires_at timestamp")
}
return &rl, nil
}
func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) {
revoList, err := ParseRevocationList(data)
if err != nil {
log.Debugf("failed to parse revocation list: %s", err)
return nil, err
}
now := time.Now().UTC()
// Validate signature timestamp
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("revocation list signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxRevocationSignatureAge {
err := fmt.Errorf("revocation list signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("revocation list signature error: %v", err)
return nil, err
}
// Ensure LastUpdated is not in the future (with clock skew tolerance)
if revoList.LastUpdated.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated)
log.Errorf("rejecting future-dated revocation list: %v", err)
return nil, err
}
// Check if the revocation list has expired
if now.After(revoList.ExpiresAt) {
err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now)
log.Errorf("rejecting expired revocation list: %v", err)
return nil, err
}
// Ensure ExpiresAt is not further in the future than maxRevocationSignatureAge
// (allows some clock skew but prevents maliciously long expiration times)
if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) {
err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt)
log.Errorf("rejecting revocation list with invalid expiration: %v", err)
return nil, err
}
// Validate signature timestamp is close to LastUpdated
// (prevents signing old lists with new timestamps)
timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs()
if timeDiff > maxClockSkew {
err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)",
signature.Timestamp, revoList.LastUpdated, timeDiff)
log.Errorf("timestamp mismatch in revocation list: %v", err)
return nil, err
}
// Reconstruct the signed message: revocation_list_data || timestamp (little-endian uint64)
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("revocation list verification failed")
}
return revoList, nil
}
func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) {
now := time.Now()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now.UTC(),
ExpiresAt: now.Add(expiration).UTC(),
}
signature, err := signRevocationList(privateRootKey, rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
}
rlData, err := json.Marshal(&rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
}
signData, err := json.Marshal(signature)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
}
return rlData, signData, nil
}
func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) {
now := time.Now().UTC()
rl.Revoked[kid] = now
rl.LastUpdated = now
rl.ExpiresAt = now.Add(expiration)
signature, err := signRevocationList(privateRootKey, rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
}
rlData, err := json.Marshal(&rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
}
signData, err := json.Marshal(signature)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
}
return rlData, signData, nil
}
func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) {
data, err := json.Marshal(rl)
if err != nil {
return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err)
}
timestamp := time.Now().UTC()
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(privateRootKey.Key, msg)
signature := &Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: privateRootKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
return signature, nil
}


@@ -0,0 +1,860 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test RevocationList marshaling/unmarshaling
func TestRevocationList_MarshalJSON(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC)
rl := &RevocationList{
Revoked: map[KeyID]time.Time{
keyID: revokedTime,
},
LastUpdated: lastUpdated,
ExpiresAt: expiresAt,
}
jsonData, err := json.Marshal(rl)
require.NoError(t, err)
// Verify it can be unmarshaled back
var decoded map[string]interface{}
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Contains(t, decoded, "revoked")
assert.Contains(t, decoded, "last_updated")
assert.Contains(t, decoded, "expires_at")
}
func TestRevocationList_UnmarshalJSON(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
jsonData := map[string]interface{}{
"revoked": map[string]string{
keyID.String(): revokedTime.Format(time.RFC3339),
},
"last_updated": lastUpdated.Format(time.RFC3339),
}
jsonBytes, err := json.Marshal(jsonData)
require.NoError(t, err)
var rl RevocationList
err = json.Unmarshal(jsonBytes, &rl)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
assert.Contains(t, rl.Revoked, keyID)
assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix())
}
func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
original := &RevocationList{
Revoked: map[KeyID]time.Time{
keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC),
},
LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC),
}
// Marshal
jsonData, err := original.MarshalJSON()
require.NoError(t, err)
// Unmarshal
var decoded RevocationList
err = decoded.UnmarshalJSON(jsonData)
require.NoError(t, err)
// Verify
assert.Len(t, decoded.Revoked, 2)
assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix())
assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix())
assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix())
}
func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) {
jsonData := []byte(`{
"revoked": {
"invalid_key_id": "2024-01-15T10:30:00Z"
},
"last_updated": "2024-01-15T11:00:00Z"
}`)
var rl RevocationList
err := json.Unmarshal(jsonData, &rl)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse KeyID")
}
func TestRevocationList_EmptyRevoked(t *testing.T) {
rl := &RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC(),
}
jsonData, err := rl.MarshalJSON()
require.NoError(t, err)
var decoded RevocationList
err = decoded.UnmarshalJSON(jsonData)
require.NoError(t, err)
assert.Empty(t, decoded.Revoked)
assert.NotNil(t, decoded.Revoked)
}
// Test ParseRevocationList
func TestParseRevocationList_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
rl := RevocationList{
Revoked: map[KeyID]time.Time{
keyID: revokedTime,
},
LastUpdated: lastUpdated,
ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC),
}
jsonData, err := rl.MarshalJSON()
require.NoError(t, err)
parsed, err := ParseRevocationList(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed)
assert.Len(t, parsed.Revoked, 1)
assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix())
}
func TestParseRevocationList_InvalidJSON(t *testing.T) {
invalidJSON := []byte("not valid json")
_, err := ParseRevocationList(invalidJSON)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParseRevocationList_MissingLastUpdated(t *testing.T) {
jsonData := []byte(`{
"revoked": {}
}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing last_updated")
}
func TestParseRevocationList_EmptyObject(t *testing.T) {
jsonData := []byte(`{}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing last_updated")
}
func TestParseRevocationList_NilRevoked(t *testing.T) {
lastUpdated := time.Now().UTC()
expiresAt := lastUpdated.Add(90 * 24 * time.Hour)
jsonData := []byte(`{
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `",
"expires_at": "` + expiresAt.Format(time.RFC3339) + `"
}`)
parsed, err := ParseRevocationList(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed.Revoked)
assert.Empty(t, parsed.Revoked)
}
func TestParseRevocationList_MissingExpiresAt(t *testing.T) {
lastUpdated := time.Now().UTC()
jsonData := []byte(`{
"revoked": {},
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `"
}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing expires_at")
}
// Test ValidateRevocationList
func TestValidateRevocationList_Valid(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Validate
rl, err := ValidateRevocationList(rootKeys, rlData, *signature)
require.NoError(t, err)
assert.NotNil(t, rl)
assert.Empty(t, rl.Revoked)
}
func TestValidateRevocationList_InvalidSignature(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Create invalid signature
invalidSig := Signature{
Signature: make([]byte, 64),
Timestamp: time.Now().UTC(),
KeyID: computeKeyID(rootPub),
Algorithm: "ed25519",
HashAlgo: "sha512",
}
// Validate should fail
_, err = ValidateRevocationList(rootKeys, rlData, invalidSig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "verification failed")
}
func TestValidateRevocationList_FutureTimestamp(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Modify timestamp to be in the future
signature.Timestamp = time.Now().UTC().Add(10 * time.Minute)
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
assert.Error(t, err)
assert.Contains(t, err.Error(), "in the future")
}
func TestValidateRevocationList_TooOld(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Modify timestamp to be too old
signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour)
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too old")
}
func TestValidateRevocationList_InvalidJSON(t *testing.T) {
rootPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
signature := Signature{
Signature: make([]byte, 64),
Timestamp: time.Now().UTC(),
KeyID: computeKeyID(rootPub),
Algorithm: "ed25519",
HashAlgo: "sha512",
}
_, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature)
assert.Error(t, err)
}
func TestValidateRevocationList_FutureLastUpdated(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with future LastUpdated
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC().Add(10 * time.Minute),
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "LastUpdated is in the future")
}
func TestValidateRevocationList_TimestampMismatch(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with LastUpdated far in the past
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC().Add(-1 * time.Hour),
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it with current timestamp
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
// Modify signature timestamp to differ too much from LastUpdated
sig.Timestamp = time.Now().UTC()
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "differs too much")
}
func TestValidateRevocationList_Expired(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list that expired in the past
now := time.Now().UTC()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now.Add(-100 * 24 * time.Hour),
ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
// Adjust signature timestamp to match LastUpdated
sig.Timestamp = rl.LastUpdated
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}
func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge)
now := time.Now().UTC()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now,
ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too far in the future")
}
// Test CreateRevocationList
func TestCreateRevocationList_Valid(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
assert.NotEmpty(t, rlData)
assert.NotEmpty(t, sigData)
// Verify it can be parsed
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
assert.False(t, rl.LastUpdated.IsZero())
// Verify signature can be parsed
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
// Test ExtendRevocationList
func TestExtendRevocationList_AddKey(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
// Generate a key to revoke
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
revokedKeyID := computeKeyID(revokedPub)
// Extend the revocation list
newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
require.NoError(t, err)
// Verify the new list
newRL, err := ParseRevocationList(newRLData)
require.NoError(t, err)
assert.Len(t, newRL.Revoked, 1)
assert.Contains(t, newRL.Revoked, revokedKeyID)
// Verify signature
sig, err := ParseSignature(newSigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestExtendRevocationList_MultipleKeys(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
// Add first key
key1Pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
key1ID := computeKeyID(key1Pub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
// Add second key
key2Pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
key2ID := computeKeyID(key2Pub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 2)
assert.Contains(t, rl.Revoked, key1ID)
assert.Contains(t, rl.Revoked, key2ID)
}
func TestExtendRevocationList_DuplicateKey(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
// Add a key
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(keyPub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
firstRevocationTime := rl.Revoked[keyID]
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Add the same key again
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
// The revocation time should be updated
secondRevocationTime := rl.Revoked[keyID]
assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime))
}
func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
firstLastUpdated := rl.LastUpdated
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Extend list
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(keyPub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
// LastUpdated should be updated
assert.True(t, rl.LastUpdated.After(firstLastUpdated))
}
// Integration test
func TestRevocationList_FullWorkflow(t *testing.T) {
// Create root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Step 1: Create empty revocation list
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 2: Validate it
sig, err := ParseSignature(sigData)
require.NoError(t, err)
rl, err := ValidateRevocationList(rootKeys, rlData, *sig)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
// Step 3: Revoke a key
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
revokedKeyID := computeKeyID(revokedPub)
rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 4: Validate the extended list
sig, err = ParseSignature(sigData)
require.NoError(t, err)
rl, err = ValidateRevocationList(rootKeys, rlData, *sig)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
assert.Contains(t, rl.Revoked, revokedKeyID)
// Step 5: Verify the revocation time is reasonable
revTime := rl.Revoked[revokedKeyID]
now := time.Now().UTC()
assert.True(t, revTime.Before(now) || revTime.Equal(now))
assert.True(t, now.Sub(revTime) < time.Minute)
}


@@ -0,0 +1,120 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"fmt"
"time"
)
const (
tagRootPrivate = "ROOT PRIVATE KEY"
tagRootPublic = "ROOT PUBLIC KEY"
)
// RootKey is a root key used to sign artifact signing keys
type RootKey struct {
PrivateKey
}
func (k RootKey) String() string {
return fmt.Sprintf(
"RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func ParseRootKey(privKeyPEM []byte) (*RootKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate)
if err != nil {
return nil, fmt.Errorf("failed to parse root Key: %w", err)
}
return &RootKey{pk}, nil
}
// ParseRootPublicKey parses a root public key from PEM format
func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err)
}
return pk, nil
}
// GenerateRootKey generates a new root key pair with metadata
func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) {
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, err
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
rk := &RootKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(rk.PrivateKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: pubJSON,
})
return rk, privPEM, pubPEM, nil
}
func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) {
timestamp := time.Now().UTC()
// Append the timestamp to the signed data so it is cryptographically bound to the signature
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(rootKey.Key, msg)
// Create signature bundle with timestamp and metadata
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: rootKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
return json.Marshal(bundle)
}


@@ -0,0 +1,476 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test RootKey.String()
func TestRootKey_String(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC)
rk := RootKey{
PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: createdAt,
ExpiresAt: expiresAt,
},
},
}
str := rk.String()
assert.Contains(t, str, "RootKey")
assert.Contains(t, str, computeKeyID(pub).String())
assert.Contains(t, str, "2024-01-15")
assert.Contains(t, str, "2034-01-15")
}
func TestRootKey_String_NoExpiration(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
rk := RootKey{
PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: createdAt,
ExpiresAt: time.Time{}, // No expiration
},
},
}
str := rk.String()
assert.Contains(t, str, "RootKey")
assert.Contains(t, str, "0001-01-01") // Zero time format
}
// Test GenerateRootKey
func TestGenerateRootKey_Valid(t *testing.T) {
expiration := 10 * 365 * 24 * time.Hour // 10 years
rk, privPEM, pubPEM, err := GenerateRootKey(expiration)
require.NoError(t, err)
assert.NotNil(t, rk)
assert.NotEmpty(t, privPEM)
assert.NotEmpty(t, pubPEM)
// Verify the key has correct metadata
assert.False(t, rk.Metadata.CreatedAt.IsZero())
assert.False(t, rk.Metadata.ExpiresAt.IsZero())
assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt))
// Verify expiration is approximately correct
expectedExpiration := time.Now().Add(expiration)
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
}
func TestGenerateRootKey_ShortExpiration(t *testing.T) {
expiration := 24 * time.Hour // 1 day
rk, _, _, err := GenerateRootKey(expiration)
require.NoError(t, err)
assert.NotNil(t, rk)
// Verify expiration
expectedExpiration := time.Now().Add(expiration)
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
}
func TestGenerateRootKey_ZeroExpiration(t *testing.T) {
rk, _, _, err := GenerateRootKey(0)
require.NoError(t, err)
assert.NotNil(t, rk)
// With zero expiration, ExpiresAt should be equal to CreatedAt
assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt)
}
func TestGenerateRootKey_PEMFormat(t *testing.T) {
rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Verify private key PEM
privBlock, _ := pem.Decode(privPEM)
require.NotNil(t, privBlock)
assert.Equal(t, tagRootPrivate, privBlock.Type)
var privKey PrivateKey
err = json.Unmarshal(privBlock.Bytes, &privKey)
require.NoError(t, err)
assert.Equal(t, rk.Key, privKey.Key)
// Verify public key PEM
pubBlock, _ := pem.Decode(pubPEM)
require.NotNil(t, pubBlock)
assert.Equal(t, tagRootPublic, pubBlock.Type)
var pubKey PublicKey
err = json.Unmarshal(pubBlock.Bytes, &pubKey)
require.NoError(t, err)
assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID)
}
func TestGenerateRootKey_KeySize(t *testing.T) {
rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Ed25519 private key should be 64 bytes
assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key))
// Ed25519 public key should be 32 bytes
pubKey := rk.Key.Public().(ed25519.PublicKey)
assert.Equal(t, ed25519.PublicKeySize, len(pubKey))
}
func TestGenerateRootKey_UniqueKeys(t *testing.T) {
rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Different keys should have different IDs
assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID)
assert.NotEqual(t, rk1.Key, rk2.Key)
}
// Test ParseRootKey
func TestParseRootKey_Valid(t *testing.T) {
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
parsed, err := ParseRootKey(privPEM)
require.NoError(t, err)
assert.NotNil(t, parsed)
// Verify the parsed key matches the original
assert.Equal(t, original.Key, parsed.Key)
assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix())
assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix())
}
func TestParseRootKey_InvalidPEM(t *testing.T) {
_, err := ParseRootKey([]byte("not a valid PEM"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse")
}
func TestParseRootKey_EmptyData(t *testing.T) {
_, err := ParseRootKey([]byte{})
assert.Error(t, err)
}
func TestParseRootKey_WrongType(t *testing.T) {
// Generate an artifact key instead of root key
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
_, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
// Try to parse artifact key as root key
_, err = ParseRootKey(privPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParseRootKey_CorruptedJSON(t *testing.T) {
// Create PEM with corrupted JSON
corruptedPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: []byte("corrupted json data"),
})
_, err := ParseRootKey(corruptedPEM)
assert.Error(t, err)
}
func TestParseRootKey_InvalidKeySize(t *testing.T) {
// Create a key with invalid size
invalidKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
privJSON, err := json.Marshal(invalidKey)
require.NoError(t, err)
invalidPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON,
})
_, err = ParseRootKey(invalidPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
func TestParseRootKey_Roundtrip(t *testing.T) {
// Generate a key
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Parse it
parsed, err := ParseRootKey(privPEM)
require.NoError(t, err)
// Generate PEM again from parsed key
privJSON2, err := json.Marshal(parsed.PrivateKey)
require.NoError(t, err)
privPEM2 := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON2,
})
// Parse again
parsed2, err := ParseRootKey(privPEM2)
require.NoError(t, err)
// Should still match original
assert.Equal(t, original.Key, parsed2.Key)
assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID)
}
// Test SignArtifactKey
func TestSignArtifactKey_Valid(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data := []byte("test data to sign")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Parse and verify signature
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, rootKey.Metadata.ID, sig.KeyID)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "sha512", sig.HashAlgo)
assert.False(t, sig.Timestamp.IsZero())
}
func TestSignArtifactKey_EmptyData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
sigData, err := SignArtifactKey(*rootKey, []byte{})
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Should still be able to parse
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestSignArtifactKey_Verify(t *testing.T) {
rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Parse public key
pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
require.NoError(t, err)
// Sign some data
data := []byte("test data for verification")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
// Parse signature
sig, err := ParseSignature(sigData)
require.NoError(t, err)
// Reconstruct message
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
// Verify signature
valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
assert.True(t, valid)
}
func TestSignArtifactKey_DifferentData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data1 := []byte("data1")
data2 := []byte("data2")
sig1, err := SignArtifactKey(*rootKey, data1)
require.NoError(t, err)
sig2, err := SignArtifactKey(*rootKey, data2)
require.NoError(t, err)
// Different data should produce different signatures
assert.NotEqual(t, sig1, sig2)
}
func TestSignArtifactKey_MultipleSignatures(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data := []byte("test data")
// Sign twice with a small delay
sig1, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
sig2, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
// Signatures should be different due to different timestamps
assert.NotEqual(t, sig1, sig2)
// Parse both signatures
parsed1, err := ParseSignature(sig1)
require.NoError(t, err)
parsed2, err := ParseSignature(sig2)
require.NoError(t, err)
// Timestamps should be different
assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp))
}
func TestSignArtifactKey_LargeData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Create 1MB of data
largeData := make([]byte, 1024*1024)
for i := range largeData {
largeData[i] = byte(i % 256)
}
sigData, err := SignArtifactKey(*rootKey, largeData)
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Verify signature can be parsed
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestSignArtifactKey_TimestampInSignature(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
beforeSign := time.Now().UTC()
data := []byte("test data")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
afterSign := time.Now().UTC()
sig, err := ParseSignature(sigData)
require.NoError(t, err)
// Timestamp should be between before and after
assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second)))
assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second)))
}
// Integration test
func TestRootKey_FullWorkflow(t *testing.T) {
// Step 1: Generate root key
rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
assert.NotNil(t, rootKey)
assert.NotEmpty(t, privPEM)
assert.NotEmpty(t, pubPEM)
// Step 2: Parse the private key back
parsedRootKey, err := ParseRootKey(privPEM)
require.NoError(t, err)
assert.Equal(t, rootKey.Key, parsedRootKey.Key)
assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID)
// Step 3: Generate an artifact key using root key
artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
assert.NotNil(t, artifactKey)
// Step 4: Verify the artifact key signature
pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
require.NoError(t, err)
sig, err := ParseSignature(artifactSig)
require.NoError(t, err)
artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic)
require.NoError(t, err)
// Reconstruct message - SignArtifactKey signs the PEM, not the JSON
msg := make([]byte, 0, len(artifactPubPEM)+8)
msg = append(msg, artifactPubPEM...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
// Verify with root public key
valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
assert.True(t, valid, "Artifact key signature should be valid")
// Step 5: Use artifact key to sign data
testData := []byte("This is test artifact data")
dataSig, err := SignData(*artifactKey, testData)
require.NoError(t, err)
assert.NotEmpty(t, dataSig)
// Step 6: Verify the artifact data signature
dataSigParsed, err := ParseSignature(dataSig)
require.NoError(t, err)
err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed)
assert.NoError(t, err, "Artifact data signature should be valid")
}
func TestRootKey_ExpiredKeyWorkflow(t *testing.T) {
// Generate a root key that expires very soon
rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond)
require.NoError(t, err)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
// Try to generate artifact key with expired root key
_, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}


@@ -0,0 +1,24 @@
package reposign
import (
"encoding/json"
"time"
)
// Signature contains a signature with associated Metadata
type Signature struct {
Signature []byte `json:"signature"`
Timestamp time.Time `json:"timestamp"`
KeyID KeyID `json:"key_id"`
Algorithm string `json:"algorithm"` // "ed25519"
HashAlgo string `json:"hash_algo"` // "blake2s" or "sha512"
}
func ParseSignature(data []byte) (*Signature, error) {
var signature Signature
if err := json.Unmarshal(data, &signature); err != nil {
return nil, err
}
return &signature, nil
}


@@ -0,0 +1,277 @@
package reposign
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseSignature_Valid(t *testing.T) {
timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
keyID, err := ParseKeyID("0123456789abcdef")
require.NoError(t, err)
signatureData := []byte{0x01, 0x02, 0x03, 0x04}
jsonData, err := json.Marshal(Signature{
Signature: signatureData,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
})
require.NoError(t, err)
sig, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.Equal(t, signatureData, sig.Signature)
assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix())
assert.Equal(t, keyID, sig.KeyID)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "blake2s", sig.HashAlgo)
}
func TestParseSignature_InvalidJSON(t *testing.T) {
invalidJSON := []byte(`{invalid json}`)
sig, err := ParseSignature(invalidJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestParseSignature_EmptyData(t *testing.T) {
emptyJSON := []byte(`{}`)
sig, err := ParseSignature(emptyJSON)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.Empty(t, sig.Signature)
assert.True(t, sig.Timestamp.IsZero())
assert.Equal(t, KeyID{}, sig.KeyID)
assert.Empty(t, sig.Algorithm)
assert.Empty(t, sig.HashAlgo)
}
func TestParseSignature_MissingFields(t *testing.T) {
// JSON with only some fields
partialJSON := []byte(`{
"signature": "AQIDBA==",
"algorithm": "ed25519"
}`)
sig, err := ParseSignature(partialJSON)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.True(t, sig.Timestamp.IsZero())
}
func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) {
timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC)
keyID, err := ParseKeyID("fedcba9876543210")
require.NoError(t, err)
original := Signature{
Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
// Marshal
jsonData, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
// Verify
assert.Equal(t, original.Signature, parsed.Signature)
assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix())
assert.Equal(t, original.KeyID, parsed.KeyID)
assert.Equal(t, original.Algorithm, parsed.Algorithm)
assert.Equal(t, original.HashAlgo, parsed.HashAlgo)
}
func TestSignature_NilSignatureBytes(t *testing.T) {
timestamp := time.Now().UTC()
keyID, err := ParseKeyID("0011223344556677")
require.NoError(t, err)
sig := Signature{
Signature: nil,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Nil(t, parsed.Signature)
}
func TestSignature_LargeSignature(t *testing.T) {
timestamp := time.Now().UTC()
keyID, err := ParseKeyID("aabbccddeeff0011")
require.NoError(t, err)
// Create a large signature (64 bytes for ed25519)
largeSignature := make([]byte, 64)
for i := range largeSignature {
largeSignature[i] = byte(i)
}
sig := Signature{
Signature: largeSignature,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, largeSignature, parsed.Signature)
}
func TestSignature_WithDifferentHashAlgorithms(t *testing.T) {
tests := []struct {
name string
hashAlgo string
}{
{"blake2s", "blake2s"},
{"sha512", "sha512"},
{"sha256", "sha256"},
{"empty", ""},
}
keyID, err := ParseKeyID("1122334455667788")
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sig := Signature{
Signature: []byte{0x01, 0x02},
Timestamp: time.Now().UTC(),
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: tt.hashAlgo,
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, tt.hashAlgo, parsed.HashAlgo)
})
}
}
func TestSignature_TimestampPrecision(t *testing.T) {
// Test that timestamp preserves precision through JSON marshaling
timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC)
keyID, err := ParseKeyID("8877665544332211")
require.NoError(t, err)
sig := Signature{
Signature: []byte{0xaa, 0xbb},
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
// Go marshals time.Time as RFC 3339 with nanosecond precision, so the
// roundtrip is effectively lossless; comparing at second granularity
// keeps the assertion simple
assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix())
}
func TestParseSignature_MalformedKeyID(t *testing.T) {
// Test with a malformed KeyID field
malformedJSON := []byte(`{
"signature": "AQID",
"timestamp": "2024-01-15T10:30:00Z",
"key_id": "invalid_keyid_format",
"algorithm": "ed25519",
"hash_algo": "blake2s"
}`)
// This should fail since "invalid_keyid_format" is not a valid KeyID
sig, err := ParseSignature(malformedJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestParseSignature_InvalidTimestamp(t *testing.T) {
// Test with an invalid timestamp format
invalidTimestampJSON := []byte(`{
"signature": "AQID",
"timestamp": "not-a-timestamp",
"key_id": "0123456789abcdef",
"algorithm": "ed25519",
"hash_algo": "blake2s"
}`)
sig, err := ParseSignature(invalidTimestampJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestSignature_ZeroKeyID(t *testing.T) {
// Test with a zero KeyID
sig := Signature{
Signature: []byte{0x01, 0x02, 0x03},
Timestamp: time.Now().UTC(),
KeyID: KeyID{},
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, KeyID{}, parsed.KeyID)
}
func TestParseSignature_ExtraFields(t *testing.T) {
// JSON with extra fields that should be ignored
jsonWithExtra := []byte(`{
"signature": "AQIDBA==",
"timestamp": "2024-01-15T10:30:00Z",
"key_id": "0123456789abcdef",
"algorithm": "ed25519",
"hash_algo": "blake2s",
"extra_field": "should be ignored",
"another_extra": 12345
}`)
sig, err := ParseSignature(jsonWithExtra)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "blake2s", sig.HashAlgo)
}


@@ -0,0 +1,187 @@
package reposign
import (
"context"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
)
const (
artifactPubKeysFileName = "artifact-key-pub.pem"
artifactPubKeysSigFileName = "artifact-key-pub.pem.sig"
revocationFileName = "revocation-list.json"
revocationSignFileName = "revocation-list.json.sig"
keySizeLimit = 5 * 1024 * 1024 // 5 MB
signatureLimit = 1024 // 1 KB
revocationLimit = 10 * 1024 * 1024 // 10 MB
)
type ArtifactVerify struct {
rootKeys []PublicKey
keysBaseURL *url.URL
revocationList *RevocationList
}
func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) {
allKeys, err := loadEmbeddedPublicKeys()
if err != nil {
return nil, err
}
return newArtifactVerify(keysBaseURL, allKeys)
}
func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) {
ku, err := url.Parse(keysBaseURL)
if err != nil {
return nil, fmt.Errorf("invalid keys base URL %q: %w", keysBaseURL, err)
}
a := &ArtifactVerify{
rootKeys: allKeys,
keysBaseURL: ku,
}
return a, nil
}
func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error {
version = strings.TrimPrefix(version, "v")
revocationList, err := a.loadRevocationList(ctx)
if err != nil {
return fmt.Errorf("failed to load revocation list: %v", err)
}
a.revocationList = revocationList
artifactPubKeys, err := a.loadArtifactKeys(ctx)
if err != nil {
return fmt.Errorf("failed to load artifact keys: %v", err)
}
signature, err := a.loadArtifactSignature(ctx, version, artifactFile)
if err != nil {
return fmt.Errorf("failed to download signature file for %s: %v", filepath.Base(artifactFile), err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
log.Errorf("failed to read artifact file: %v", err)
return fmt.Errorf("failed to read artifact file: %w", err)
}
if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil {
return fmt.Errorf("failed to validate artifact: %v", err)
}
return nil
}
func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) {
downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String()
data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit)
if err != nil {
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
return nil, err
}
downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String()
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download revocation list signature '%s': %s", downloadURL, err)
return nil, err
}
signature, err := ParseSignature(sigData)
if err != nil {
log.Debugf("failed to parse revocation list signature: %s", err)
return nil, err
}
return ValidateRevocationList(a.rootKeys, data, *signature)
}
func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) {
downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String()
log.Debugf("downloading artifact keys from: %s", downloadURL)
data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit)
if err != nil {
log.Debugf("failed to download artifact keys: %s", err)
return nil, err
}
downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String()
log.Debugf("downloading artifact public key signature from: %s", downloadURL)
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download signature of public keys: %s", err)
return nil, err
}
signature, err := ParseSignature(sigData)
if err != nil {
log.Debugf("failed to parse signature of public keys: %s", err)
return nil, err
}
return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList)
}
func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) {
artifactFile = filepath.Base(artifactFile)
downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String()
data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download artifact signature: %s", err)
return nil, err
}
signature, err := ParseSignature(data)
if err != nil {
log.Debugf("failed to parse artifact signature: %s", err)
return nil, err
}
return signature, nil
}
func loadEmbeddedPublicKeys() ([]PublicKey, error) {
files, err := embeddedCerts.ReadDir(embeddedCertsDir)
if err != nil {
return nil, fmt.Errorf("failed to read embedded certs: %w", err)
}
var allKeys []PublicKey
for _, file := range files {
if file.IsDir() {
continue
}
data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name())
if err != nil {
return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err)
}
keys, err := parsePublicKeyBundle(data, tagRootPublic)
if err != nil {
return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err)
}
allKeys = append(allKeys, keys...)
}
if len(allKeys) == 0 {
return nil, fmt.Errorf("no valid public keys found in embedded certs")
}
return allKeys, nil
}


@@ -0,0 +1,528 @@
package reposign
import (
"context"
"crypto/ed25519"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test ArtifactVerify construction
func TestArtifactVerify_Construction(t *testing.T) {
// Generate test root key
rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic)
require.NoError(t, err)
keysBaseURL := "http://localhost:8080/artifact-signatures"
av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey})
require.NoError(t, err)
assert.NotNil(t, av)
assert.NotEmpty(t, av.rootKeys)
assert.Equal(t, keysBaseURL, av.keysBaseURL.String())
// Verify root key structure
assert.NotEmpty(t, av.rootKeys[0].Key)
assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID)
assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero())
}
func TestArtifactVerify_MultipleRootKeys(t *testing.T) {
// Generate multiple test root keys
rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic)
require.NoError(t, err)
rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic)
require.NoError(t, err)
keysBaseURL := "http://localhost:8080/artifact-signatures"
av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2})
assert.NoError(t, err)
assert.Len(t, av.rootKeys, 2)
assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID)
}
// Test Verify workflow with mock HTTP server
func TestArtifactVerify_FullWorkflow(t *testing.T) {
// Create temporary test directory
tempDir := t.TempDir()
// Step 1: Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Step 2: Generate artifact key
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
// Step 3: Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 4: Bundle artifact keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Step 5: Create test artifact
artifactPath := filepath.Join(tempDir, "test-artifact.bin")
artifactData := []byte("This is test artifact data for verification")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
// Step 6: Sign artifact
artifactSigData, err := SignData(*artifactKey, artifactData)
require.NoError(t, err)
// Step 7: Setup mock HTTP server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifacts/v1.0.0/test-artifact.bin":
_, _ = w.Write(artifactData)
case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
// Step 8: Create ArtifactVerify with test root key
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
// Step 9: Verify artifact
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.NoError(t, err)
}
func TestArtifactVerify_InvalidRevocationList(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
// Setup server with invalid revocation list
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write([]byte("invalid data"))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to load revocation list")
}
func TestArtifactVerify_MissingArtifactFile(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
// Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Create signature for non-existent file
testData := []byte("test")
artifactSigData, err := SignData(*artifactKey, testData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/missing.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", "file.bin")
assert.Error(t, err)
}
func TestArtifactVerify_ServerUnavailable(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
// Use URL that doesn't exist
av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
}
func TestArtifactVerify_ContextCancellation(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
// Create a server that delays response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(500 * time.Millisecond)
_, _ = w.Write([]byte("data"))
}))
defer server.Close()
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey})
require.NoError(t, err)
// Create context that cancels quickly
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
}
func TestArtifactVerify_WithRevocation(t *testing.T) {
tempDir := t.TempDir()
// Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Generate two artifact keys
artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
require.NoError(t, err)
_, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
require.NoError(t, err)
// Create revocation list with first key revoked
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
parsedRevocation, err := ParseRevocationList(emptyRevocation)
require.NoError(t, err)
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle both keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
require.NoError(t, err)
// Create artifact signed by revoked key
artifactPath := filepath.Join(tempDir, "test.bin")
artifactData := []byte("test data")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
artifactSigData, err := SignData(*artifactKey1, artifactData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should fail because the signing key is revoked
assert.Error(t, err)
assert.Contains(t, err.Error(), "no signing Key found")
}
func TestArtifactVerify_ValidWithSecondKey(t *testing.T) {
tempDir := t.TempDir()
// Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Generate two artifact keys
_, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
require.NoError(t, err)
artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
require.NoError(t, err)
// Create revocation list with first key revoked
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
parsedRevocation, err := ParseRevocationList(emptyRevocation)
require.NoError(t, err)
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle both keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
require.NoError(t, err)
// Create artifact signed by second key (not revoked)
artifactPath := filepath.Join(tempDir, "test.bin")
artifactData := []byte("test data")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
artifactSigData, err := SignData(*artifactKey2, artifactData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should succeed because second key is not revoked
assert.NoError(t, err)
}
func TestArtifactVerify_TamperedArtifact(t *testing.T) {
tempDir := t.TempDir()
// Generate root key and artifact key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
// Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Sign original data
originalData := []byte("original data")
artifactSigData, err := SignData(*artifactKey, originalData)
require.NoError(t, err)
// Write tampered data to file
artifactPath := filepath.Join(tempDir, "test.bin")
tamperedData := []byte("tampered data")
err = os.WriteFile(artifactPath, tamperedData, 0644)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should fail because artifact was tampered
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to validate artifact")
}
// Test URL validation
func TestArtifactVerify_URLParsing(t *testing.T) {
tests := []struct {
name string
keysBaseURL string
expectError bool
}{
{
name: "Valid HTTP URL",
keysBaseURL: "http://example.com/artifact-signatures",
expectError: false,
},
{
name: "Valid HTTPS URL",
keysBaseURL: "https://example.com/artifact-signatures",
expectError: false,
},
{
name: "URL with port",
keysBaseURL: "http://localhost:8080/artifact-signatures",
expectError: false,
},
{
name: "Invalid URL",
keysBaseURL: "://invalid",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := newArtifactVerify(tt.keysBaseURL, nil)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}


@@ -0,0 +1,11 @@
package updatemanager
import v "github.com/hashicorp/go-version"
type UpdateInterface interface {
StopWatch()
SetDaemonVersion(newVersion string) bool
SetOnUpdateListener(updateFn func())
LatestVersion() *v.Version
StartFetcher()
}


@@ -75,6 +75,8 @@ type Client struct {
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
}
// NewClient instantiates a new Client
@@ -92,17 +94,44 @@ func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName s
}
}
// SetConfigFromJSON loads config from a JSON string into memory.
// This is used on tvOS where file writes to App Group containers are blocked.
// When set, IsLoginRequired() and Run() will use this preloaded config instead of reading from file.
func (c *Client) SetConfigFromJSON(jsonStr string) error {
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
if err != nil {
log.Errorf("SetConfigFromJSON: failed to parse config JSON: %v", err)
return err
}
c.preloadedConfig = cfg
log.Infof("SetConfigFromJSON: config loaded successfully from JSON")
return nil
}
// Run starts the internal client. It is a blocking function
func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
exportEnvList(envList)
log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return err
var cfg *profilemanager.Config
var err error
// Use preloaded config if available (tvOS where file writes are blocked)
if c.preloadedConfig != nil {
log.Infof("Run: using preloaded config from memory")
cfg = c.preloadedConfig
} else {
log.Infof("Run: loading config from file")
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return err
}
}
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -120,7 +149,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.ctxCancelLock.Unlock()
auth := NewAuthWithConfig(ctx, cfg)
err = auth.Login()
err = auth.LoginSync()
if err != nil {
return err
}
@@ -131,7 +160,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
@@ -208,14 +237,45 @@ func (c *Client) IsLoginRequired() bool {
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
var cfg *profilemanager.Config
var err error
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
// Use preloaded config if available (tvOS where file writes are blocked)
if c.preloadedConfig != nil {
log.Infof("IsLoginRequired: using preloaded config from memory")
cfg = c.preloadedConfig
} else {
log.Infof("IsLoginRequired: loading config from file")
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
log.Errorf("IsLoginRequired: failed to load config: %v", err)
// If we can't load config, assume login is required
return true
}
}
if cfg == nil {
log.Errorf("IsLoginRequired: config is nil")
return true
}
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
if err != nil {
log.Errorf("IsLoginRequired: check failed: %v", err)
// If the check fails, assume login is required to be safe
return true
}
log.Infof("IsLoginRequired: needsLogin=%v", needsLogin)
return needsLogin
}
// loginForMobileAuthTimeout is the timeout for requesting auth info from the server
const loginForMobileAuthTimeout = 30 * time.Second
func (c *Client) LoginForMobile() string {
var ctx context.Context
//nolint
@@ -228,16 +288,26 @@ func (c *Client) LoginForMobile() string {
defer c.ctxCancelLock.Unlock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
cfg, _ := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
cfg, err := profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
if err != nil {
log.Errorf("LoginForMobile: failed to load config: %v", err)
return fmt.Sprintf("failed to load config: %v", err)
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, false, "")
if err != nil {
return err.Error()
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
// Use a bounded timeout for the auth info request to prevent indefinite hangs
authInfoCtx, authInfoCancel := context.WithTimeout(ctx, loginForMobileAuthTimeout)
defer authInfoCancel()
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
if err != nil {
return err.Error()
}
@@ -249,10 +319,14 @@ func (c *Client) LoginForMobile() string {
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
_ = internal.Login(ctx, cfg, "", jwtToken)
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
log.Errorf("LoginForMobile: Login failed: %v", err)
return
}
c.loginComplete = true
}()


@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -33,7 +34,8 @@ type ErrListener interface {
// URLOpener is a callback interface. The Open function is triggered when
// the backend wants to show a URL to the user
type URLOpener interface {
Open(string)
Open(url string, userCode string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -72,13 +74,32 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
// SaveConfigIfSSOSupported tests connectivity with the management server by retrieving the server device flow info.
// If flow info is returned, it saves the configuration and reports true. A codes.NotFound response means that SSO
// is not supported, so it reports false without saving the configuration. Other errors also report false.
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
if listener == nil {
log.Errorf("SaveConfigIfSSOSupported: listener is nil")
return
}
go func() {
sso, err := a.saveConfigIfSSOSupported()
if err != nil {
listener.OnError(err)
} else {
listener.OnSuccess(sso)
}
}()
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
@@ -97,12 +118,29 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
return true, err
}
// LoginWithSetupKeyAndSaveConfig tests connectivity with the management server using the setup key.
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupKey string, deviceName string) {
if resultListener == nil {
log.Errorf("LoginWithSetupKeyAndSaveConfig: resultListener is nil")
return
}
go func() {
err := a.loginWithSetupKeyAndSaveConfig(setupKey, deviceName)
if err != nil {
resultListener.OnError(err)
} else {
resultListener.OnSuccess()
}
}()
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
@@ -118,10 +156,14 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
return fmt.Errorf("backoff cycle failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
return profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
}
func (a *Auth) Login() error {
// LoginSync performs a synchronous login check without UI interaction.
// It is used for background VPN connections where the user is already authenticated.
func (a *Auth) LoginSync() error {
var needsLogin bool
// check if we need to generate JWT token
@@ -135,23 +177,142 @@ func (a *Auth) Login() error {
jwtToken := ""
if needsLogin {
return fmt.Errorf("Not authenticated")
return fmt.Errorf("not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
return nil
}
// Login performs interactive login with device authentication support
// Deprecated: Use LoginWithDeviceName instead to ensure proper device naming on tvOS
func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool) {
// Use empty device name - system will use hostname as fallback
a.LoginWithDeviceName(resultListener, urlOpener, forceDeviceAuth, "")
}
// LoginWithDeviceName performs interactive login with device authentication support
// The deviceName parameter allows specifying a custom device name (required for tvOS)
func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpener, forceDeviceAuth bool, deviceName string) {
if resultListener == nil {
log.Errorf("LoginWithDeviceName: resultListener is nil")
return
}
if urlOpener == nil {
log.Errorf("LoginWithDeviceName: urlOpener is nil")
resultListener.OnError(fmt.Errorf("urlOpener is nil"))
return
}
go func() {
err := a.login(urlOpener, forceDeviceAuth, deviceName)
if err != nil {
resultListener.OnError(err)
} else {
resultListener.OnSuccess()
}
}()
}
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
var needsLogin bool
// Create context with device name if provided
ctx := a.ctx
if deviceName != "" {
//nolint:staticcheck
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
}
// check if we need to generate JWT token
err := a.withBackOff(ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
return
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(ctx, func() error {
err := internal.Login(ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
// Save the config before notifying success to ensure persistence completes
// before the callback potentially triggers teardown on the Swift side.
// Note: This differs from Android which doesn't save config after login.
// On iOS/tvOS, we save here because:
// 1. The config may have been modified during login (e.g., new tokens)
// 2. On tvOS, the Network Extension context may be the only place with
// write permissions to the App Group container
if a.cfgPath != "" {
if err := profilemanager.DirectWriteOutConfig(a.cfgPath, a.config); err != nil {
log.Warnf("failed to save config after login: %v", err)
}
}
// Notify caller of successful login synchronously before returning
urlOpener.OnLoginSuccess()
return nil
}
const authInfoRequestTimeout = 30 * time.Second
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
if err != nil {
return nil, err
}
// Use a bounded timeout for the auth info request to prevent indefinite hangs
authInfoCtx, authInfoCancel := context.WithTimeout(a.ctx, authInfoRequestTimeout)
defer authInfoCancel()
flowInfo, err := oAuthFlow.RequestAuthInfo(authInfoCtx)
if err != nil {
return nil, fmt.Errorf("requesting OAuth flow info failed: %v", err)
}
urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
@@ -160,3 +321,24 @@ func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
// GetConfigJSON returns the current config as a JSON string.
// This can be used by the caller to persist the config via alternative storage
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
func (a *Auth) GetConfigJSON() (string, error) {
if a.config == nil {
return "", fmt.Errorf("no config available")
}
return profilemanager.ConfigToJSON(a.config)
}
// SetConfigFromJSON loads config from a JSON string.
// This can be used to restore config from alternative storage mechanisms.
func (a *Auth) SetConfigFromJSON(jsonStr string) error {
cfg, err := profilemanager.ConfigFromJSON(jsonStr)
if err != nil {
return err
}
a.config = cfg
return nil
}


@@ -112,6 +112,8 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
// Commit write out the changes into config file
func (p *Preferences) Commit() error {
_, err := profilemanager.UpdateOrCreateConfig(p.configInput)
// Use DirectUpdateOrCreateConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
_, err := profilemanager.DirectUpdateOrCreateConfig(p.configInput)
return err
}


@@ -51,7 +51,7 @@
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />


@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v6.32.1
// protoc v3.21.12
// source: daemon.proto
package proto
@@ -893,6 +893,7 @@ type UpRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -941,6 +942,13 @@ func (x *UpRequest) GetUsername() string {
return ""
}
func (x *UpRequest) GetAutoUpdate() bool {
if x != nil && x.AutoUpdate != nil {
return *x.AutoUpdate
}
return false
}
type UpResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -2005,6 +2013,7 @@ type SSHSessionInfo struct {
RemoteAddress string `protobuf:"bytes,2,opt,name=remoteAddress,proto3" json:"remoteAddress,omitempty"`
Command string `protobuf:"bytes,3,opt,name=command,proto3" json:"command,omitempty"`
JwtUsername string `protobuf:"bytes,4,opt,name=jwtUsername,proto3" json:"jwtUsername,omitempty"`
PortForwards []string `protobuf:"bytes,5,rep,name=portForwards,proto3" json:"portForwards,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -2067,6 +2076,13 @@ func (x *SSHSessionInfo) GetJwtUsername() string {
return ""
}
func (x *SSHSessionInfo) GetPortForwards() []string {
if x != nil {
return x.PortForwards
}
return nil
}
// SSHServerState contains the latest state of the SSH server
type SSHServerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -5356,6 +5372,94 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
type InstallerResultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InstallerResultRequest) Reset() {
*x = InstallerResultRequest{}
mi := &file_daemon_proto_msgTypes[79]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InstallerResultRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InstallerResultRequest) ProtoMessage() {}
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
}
type InstallerResultResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InstallerResultResponse) Reset() {
*x = InstallerResultResponse{}
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InstallerResultResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InstallerResultResponse) ProtoMessage() {}
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
}
func (x *InstallerResultResponse) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *InstallerResultResponse) GetErrorMsg() string {
if x != nil {
return x.ErrorMsg
}
return ""
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5366,7 +5470,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5378,7 +5482,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5502,12 +5606,16 @@ const file_daemon_proto_rawDesc = "" +
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
"\x14WaitSSOLoginResponse\x12\x14\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"p\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" +
"\tUpRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" +
"\n" +
"autoUpdate\x18\x03 \x01(\bH\x02R\n" +
"autoUpdate\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\f\n" +
"\t_usernameB\r\n" +
"\v_autoUpdate\"\f\n" +
"\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
@@ -5606,12 +5714,13 @@ const file_daemon_proto_rawDesc = "" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +
"\aenabled\x18\x03 \x01(\bR\aenabled\x12\x14\n" +
"\x05error\x18\x04 \x01(\tR\x05error\"\x8e\x01\n" +
"\x05error\x18\x04 \x01(\tR\x05error\"\xb2\x01\n" +
"\x0eSSHSessionInfo\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12$\n" +
"\rremoteAddress\x18\x02 \x01(\tR\rremoteAddress\x12\x18\n" +
"\acommand\x18\x03 \x01(\tR\acommand\x12 \n" +
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\"^\n" +
"\vjwtUsername\x18\x04 \x01(\tR\vjwtUsername\x12\"\n" +
"\fportForwards\x18\x05 \x03(\tR\fportForwards\"^\n" +
"\x0eSSHServerState\x12\x18\n" +
"\aenabled\x18\x01 \x01(\bR\aenabled\x122\n" +
"\bsessions\x18\x02 \x03(\v2\x16.daemon.SSHSessionInfoR\bsessions\"\xaf\x04\n" +
@@ -5893,7 +6002,11 @@ const file_daemon_proto_rawDesc = "" +
"\x14WaitJWTTokenResponse\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -5902,7 +6015,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xdb\x12\n" +
"\x05TRACE\x10\a2\xb4\x13\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -5938,7 +6051,8 @@ const file_daemon_proto_rawDesc = "" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -5953,7 +6067,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6038,19 +6152,21 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
nil, // 83: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range
nil, // 85: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 86: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
nil, // 85: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
nil, // 87: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6061,8 +6177,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6073,10 +6189,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6111,40 +6227,42 @@ var file_daemon_proto_depIdxs = []int32{
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
66, // [66:98] is the sub-list for method output_type
34, // [34:66] is the sub-list for method input_type
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
67, // [67:100] is the sub-list for method output_type
34, // [34:67] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6174,7 +6292,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 82,
NumMessages: 84,
NumExtensions: 0,
NumServices: 1,
},


@@ -95,6 +95,8 @@ service DaemonService {
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
}
@@ -215,6 +217,7 @@ message WaitSSOLoginResponse {
message UpRequest {
optional string profileName = 1;
optional string username = 2;
optional bool autoUpdate = 3;
}
message UpResponse {}
@@ -369,6 +372,7 @@ message SSHSessionInfo {
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
repeated string portForwards = 5;
}
// SSHServerState contains the latest state of the SSH server
@@ -772,3 +776,11 @@ message WaitJWTTokenResponse {
// expiration time in seconds
int64 expiresIn = 3;
}
message InstallerResultRequest {
}
message InstallerResultResponse {
bool success = 1;
string errorMsg = 2;
}


@@ -71,6 +71,7 @@ type DaemonServiceClient interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
type daemonServiceClient struct {
@@ -392,6 +393,15 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec
return out, nil
}
func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) {
out := new(InstallerResultResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -449,6 +459,7 @@ type DaemonServiceServer interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -552,6 +563,9 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1144,6 +1158,24 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InstallerResultRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetInstallerResult(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetInstallerResult",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1275,6 +1307,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,
},
{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
},
Streams: []grpc.StreamDesc{
{


@@ -14,4 +14,4 @@ cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"
cd "$old_pwd"


@@ -145,10 +145,10 @@ func (s *Server) Start() error {
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
// set the default config if not exists
if err := s.setDefaultConfigIfNotExists(ctx); err != nil {
log.Errorf("failed to set default config: %v", err)
return fmt.Errorf("failed to set default config: %w", err)
// copy old default config
_, err = s.profileManager.CopyDefaultProfileIfNotExists()
if err != nil && !errors.Is(err, profilemanager.ErrorOldDefaultConfigNotFound) {
return err
}
activeProf, err := s.profileManager.GetActiveProfileState()
@@ -156,23 +156,11 @@ func (s *Server) Start() error {
return fmt.Errorf("failed to get active profile state: %w", err)
}
config, err := s.getConfig(activeProf)
config, existingConfig, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
Username: "",
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return fmt.Errorf("failed to set active profile state: %w", err)
}
config, err = profilemanager.GetConfig(s.profileManager.DefaultProfilePath())
if err != nil {
log.Errorf("failed to get default profile config: %v", err)
return fmt.Errorf("failed to get default profile config: %w", err)
}
return err
}
s.config = config
@@ -186,44 +174,27 @@ func (s *Server) Start() error {
}
if config.DisableAutoConnect {
state.Set(internal.StatusIdle)
return nil
}
if !existingConfig {
log.Warnf("not trying to connect when configuration was just created")
state.Set(internal.StatusNeedsLogin)
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
ok, err := s.profileManager.CopyDefaultProfileIfNotExists()
if err != nil {
if err := s.profileManager.CreateDefaultProfile(); err != nil {
log.Errorf("failed to create default profile: %v", err)
return fmt.Errorf("failed to create default profile: %w", err)
}
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
Username: "",
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return fmt.Errorf("failed to set active profile state: %w", err)
}
}
if ok {
state := internal.CtxGetState(ctx)
state.Set(internal.StatusNeedsLogin)
}
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
// connectWithRetryRuns runs the client connection with a backoff strategy, retrying the operation as an additional
// mechanism to keep the client connected even when the connection is lost.
// Retries are canceled if the client receives a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -231,7 +202,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
@@ -260,7 +231,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
doInitialAutoUpdate = false
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
@@ -486,7 +458,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.mutex.Unlock()
config, err := s.getConfig(activeProf)
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
@@ -715,7 +687,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
config, err := s.getConfig(activeProf)
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
@@ -728,7 +700,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
var doAutoUpdate bool
if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
doAutoUpdate = true
}
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
@@ -805,7 +782,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
config, err := s.getConfig(activeProf)
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get default profile config: %v", err)
return nil, fmt.Errorf("failed to get default profile config: %w", err)
@@ -902,7 +879,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed to get active profile state: %v", err)
}
config, err := s.getConfig(activeProf)
config, _, err := s.getConfig(activeProf)
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "not logged in")
}
@@ -926,19 +903,24 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return &proto.LogoutResponse{}, nil
}
// getConfig loads the config from the active profile
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, error) {
// getConfig reads the config file and returns the Config and whether the config file already existed. Errors out if it does not exist
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
return nil, false, fmt.Errorf("failed to get active profile file path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
_, err = os.Stat(cfgPath)
configExisted := !os.IsNotExist(err)
log.Infof("active profile config existed: %t, err %v", configExisted, err)
config, err := profilemanager.ReadConfig(cfgPath)
if err != nil {
return nil, fmt.Errorf("failed to get config: %w", err)
return nil, false, fmt.Errorf("failed to get config: %w", err)
}
return config, nil
return config, configExisted, nil
}
func (s *Server) canRemoveProfile(profileName string) error {
@@ -1122,6 +1104,7 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
PortForwards: session.PortForwards,
})
}
@@ -1539,9 +1522,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err


@@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -326,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}


@@ -0,0 +1,30 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/proto"
)
func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) {
inst := installer.New()
dir := inst.TempDir()
rh := installer.NewResultHandler(dir)
result, err := rh.Watch(ctx)
if err != nil {
log.Errorf("failed to watch update result: %v", err)
return &proto.InstallerResultResponse{
Success: false,
ErrorMsg: err.Error(),
}, nil
}
return &proto.InstallerResultResponse{
Success: result.Success,
ErrorMsg: result.Error,
}, nil
}

client/ssh/auth/auth.go Normal file

@@ -0,0 +1,177 @@
package auth
import (
"errors"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
)
var (
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
)
// Authorizer handles SSH fine-grained access control authorization
type Authorizer struct {
// UserIDClaim is the JWT claim to extract the user ID from
userIDClaim string
// authorizedUsers is a list of hashed user IDs authorized to access this peer
authorizedUsers []sshuserhash.UserIDHash
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// mu protects the list of users
mu sync.RWMutex
}
// Config contains configuration for the SSH authorizer
type Config struct {
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
UserIDClaim string
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
AuthorizedUsers []sshuserhash.UserIDHash
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to log in as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
}
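Config documents AuthorizedUsers as FNV-1a 64-bit hashes. A minimal sketch of producing such a hash with the standard library follows; the hex rendering is an assumption here, and the actual encoding used by sshauth.HashUserID may differ:

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// hashUserID hashes a JWT user ID with FNV-1a 64-bit and renders it as hex.
// Illustrative only: sshauth.HashUserID may use a different representation.
func hashUserID(userID string) (string, error) {
	if userID == "" {
		return "", fmt.Errorf("empty user ID")
	}
	h := fnv.New64a()
	// hash.Hash.Write is documented to never return an error
	h.Write([]byte(userID))
	return fmt.Sprintf("%016x", h.Sum64()), nil
}

func main() {
	sum, err := hashUserID("user-123")
	fmt.Println(sum, err)
}
```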
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
}
// Update updates the authorizer configuration with new values
func (a *Authorizer) Update(config *Config) {
a.mu.Lock()
defer a.mu.Unlock()
if config == nil {
// Clear authorization
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
log.Info("SSH authorization cleared")
return
}
userIDClaim := config.UserIDClaim
if userIDClaim == "" {
userIDClaim = DefaultUserIDClaim
}
a.userIDClaim = userIDClaim
// Store authorized users list
a.authorizedUsers = config.AuthorizedUsers
// Store machine users mapping
machineUsers := make(map[string][]uint32)
for osUser, indexes := range config.MachineUsers {
if len(indexes) > 0 {
machineUsers[osUser] = indexes
}
}
a.machineUsers = machineUsers
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates whether a user is authorized to log in as the specified OS user.
// Returns a success message describing how authorization was granted, or an error.
func (a *Authorizer) Authorize(jwtUserID, osUsername string) (string, error) {
if jwtUserID == "" {
return "", fmt.Errorf("JWT user ID is empty for OS user %q: %w", osUsername, ErrEmptyUserID)
}
// Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil {
return "", fmt.Errorf("hash user ID %q for OS user %q: %w", jwtUserID, osUsername, err)
}
a.mu.RLock()
defer a.mu.RUnlock()
// Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID)
if !found {
return "", fmt.Errorf("user %q (hash: %s) not in authorized list for OS user %q: %w", jwtUserID, hashedUserID, osUsername, ErrUserNotAuthorized)
}
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
}
// checkMachineUserMapping validates whether a user's index is authorized for the specified OS user.
// Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) (string, error) {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
return fmt.Sprintf("granted via wildcard (index: %d)", userIndex), nil
}
}
// Check for specific OS username mapping
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed)
return "", fmt.Errorf("no machine user mapping for OS user %q (JWT user: %s): %w", osUsername, jwtUserID, ErrNoMachineUserMapping)
}
// Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
return "", fmt.Errorf("user %q not mapped to OS user %q (index: %d): %w", jwtUserID, osUsername, userIndex, ErrUserNotMappedToOSUser)
}
return fmt.Sprintf("granted (index: %d)", userIndex), nil
}
// GetUserIDClaim returns the JWT claim name used to extract user IDs
func (a *Authorizer) GetUserIDClaim() string {
a.mu.RLock()
defer a.mu.RUnlock()
return a.userIDClaim
}
// findUserIndex returns the index of a hashed user ID in the authorized users list.
// It returns the index and true if found, or 0 and false if not found.
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
for i, id := range a.authorizedUsers {
if id == hashedUserID {
return i, true
}
}
return 0, false
}
// isIndexInList reports whether an index exists in a list of indexes.
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
for _, idx := range indexes {
if idx == index {
return true
}
}
return false
}
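The wildcard-first, fail-closed lookup in `checkMachineUserMapping` can be sketched as a standalone snippet. The `authorize` helper below is illustrative only, not this package's API; it mirrors the semantics: a user's index grants access via the `"*"` entry first, then via the specific OS-user entry, and anything unmapped is denied.

```go
package main

import "fmt"

// authorize mirrors the wildcard-first, fail-closed lookup: a user,
// identified by its index in the authorized-users list, may log in as
// osUser if that index appears under "*" or under the specific OS user.
func authorize(machineUsers map[string][]uint32, userIndex uint32, osUser string) bool {
	for _, idx := range machineUsers["*"] { // wildcard takes precedence
		if idx == userIndex {
			return true
		}
	}
	allowed, ok := machineUsers[osUser]
	if !ok {
		return false // no mapping for this OS user: deny by default (fail closed)
	}
	for _, idx := range allowed {
		if idx == userIndex {
			return true
		}
	}
	return false
}

func main() {
	m := map[string][]uint32{"*": {0}, "alice": {1}}
	fmt.Println(authorize(m, 0, "root"))  // true: index 0 has wildcard access
	fmt.Println(authorize(m, 1, "alice")) // true: explicit mapping
	fmt.Println(authorize(m, 1, "root"))  // false: no mapping, fail closed
}
```

Referring to users by list index rather than by hash keeps the per-OS-user mapping compact: each entry is a small `[]uint32` instead of a list of hashes.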

View File

@@ -0,0 +1,612 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/sshauth"
)
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list with one user
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Try to authorize a different user
_, err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
}
authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed)
_, err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
_, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should access root and admin
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
_, err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err)
// user2 (index 1) should access root and postgres
_, err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
_, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user3 (index 2) should access postgres
_, err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should NOT access postgres
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin
_, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root
_, err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin
_, err = authorizer.Authorize("user3", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
}
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only root is mapped
},
}
authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed)
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Empty user ID should fail
_, err = authorizer.Authorize("", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID)
}
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up multiple authorized users
userHashes := make([]sshauth.UserIDHash, 10)
for i := 0; i < 10; i++ {
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
require.NoError(t, err)
userHashes[i] = hash
}
// Create machine user mapping for all users
rootIndexes := make([]uint32, 10)
for i := 0; i < 10; i++ {
rootIndexes[i] = uint32(i)
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": rootIndexes,
},
}
authorizer.Update(config)
// All users should be authorized for root
for i := 0; i < 10; i++ {
_, err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i)
}
// User not in list should fail
_, err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer := NewAuthorizer()
// Set up initial configuration
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{"root": {0}},
}
authorizer.Update(config)
// user1 should be authorized
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Clear configuration
authorizer.Update(nil)
// user1 should no longer be authorized
_, err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Machine users with empty index lists should be filtered out
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
"postgres": {}, // empty list - should be filtered out
"admin": nil, // nil list - should be filtered out
},
}
authorizer.Update(config)
// root should work
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// postgres should fail (no mapping)
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping)
_, err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Set up with custom user ID claim
user1Hash, err := sshauth.HashUserID("user@example.com")
require.NoError(t, err)
config := &Config{
UserIDClaim: "email",
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
},
}
authorizer.Update(config)
// Verify the custom claim is set
assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID
_, err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err)
}
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Verify default claim
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
// Set up with empty user ID claim (should use default)
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: "", // empty - should use default
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Should fall back to default
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
}
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer := NewAuthorizer()
// Create a large authorized users list
const numUsers = 1000
userHashes := make([]sshauth.UserIDHash, numUsers)
for i := 0; i < numUsers; i++ {
hash, err := sshauth.HashUserID("user" + string(rune(i)))
require.NoError(t, err)
userHashes[i] = hash
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": {0, 500, 999}, // first, middle, and last user
},
}
authorizer.Update(config)
// First user should have access
_, err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err)
// Middle user should have access
_, err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err)
// Last user should have access
_, err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err)
// User not in mapping should NOT have access
_, err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err)
}
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1},
},
}
authorizer.Update(config)
// Test concurrent authorization calls (should be safe to read concurrently)
const numGoroutines = 100
errChan := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(idx int) {
user := "user1"
if idx%2 == 0 {
user = "user2"
}
_, err := authorizer.Authorize(user, "root")
errChan <- err
}(i)
}
// Wait for all goroutines to complete and collect errors
for i := 0; i < numGoroutines; i++ {
err := <-errChan
assert.NoError(t, err)
}
}
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
// Configure with wildcard - all authorized users can access any OS user
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1, 2}, // wildcard with all user indexes
},
}
authorizer.Update(config)
// All authorized users should be able to access any OS user
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
_, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
_, err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err)
_, err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err)
_, err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err)
_, err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err)
}
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Configure with wildcard
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"*": {0},
},
}
authorizer.Update(config)
// user1 should have access
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard
_, err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with both wildcard and specific mappings
// Wildcard takes precedence for users in the wildcard index list
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1}, // wildcard for both users
"root": {0}, // specific mapping that would normally restrict to user1 only
},
}
authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
_, err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard
_, err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err)
_, err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err)
}
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure WITHOUT wildcard - only specific mappings
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only user1
"postgres": {1}, // only user2
},
}
authorizer.Update(config)
// user1 can access root
_, err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// user2 can access postgres
_, err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user1 cannot access postgres
_, err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root
_, err = authorizer.Authorize("user2", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users
_, err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
_, err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
// This test covers the scenario where wildcard exists with limited indexes.
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
// Other users can only access OS users they are explicitly mapped to.
authorizer := NewAuthorizer()
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
wasmHash, err := sshauth.HashUserID("wasm")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with wildcard having only index 0, and specific mappings for other OS users
config := &Config{
UserIDClaim: "sub",
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
"alice": {1}, // specific mapping for user2
"bob": {1}, // specific mapping for user2
},
}
authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard
_, err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard")
_, err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard")
_, err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard")
_, err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
_, err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping")
_, err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping")
_, err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
_, err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied
_, err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
@@ -551,14 +550,15 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
log.Debugf("local port forwarding: close local connection: %v", err)
}
}()
channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil {
if strings.Contains(err.Error(), "administratively prohibited") {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
var openErr *ssh.OpenChannelError
if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
} else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
}
@@ -566,19 +566,11 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
log.Debugf("local port forwarding: close remote channel: %v", err)
}
}()
go func() {
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("local forward copy error (local->remote): %v", err)
}
}()
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("local forward copy error (remote->local): %v", err)
}
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// RemotePortForward sets up remote port forwarding, binding on the remote host and forwarding to localAddr
@@ -633,7 +625,7 @@ func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
return fmt.Errorf("send tcpip-forward request: %w", err)
}
if !ok {
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
return fmt.Errorf("remote port forwarding denied by server")
}
return nil
}
@@ -676,7 +668,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
log.Debugf("remote port forwarding: close remote channel: %v", err)
}
}()
@@ -688,19 +680,11 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
log.Debugf("remote port forwarding: close local connection: %v", err)
}
}()
go func() {
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("remote forward copy error (remote->local): %v", err)
}
}()
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("remote forward copy error (local->remote): %v", err)
}
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
}
// tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -193,3 +193,64 @@ func buildAddressList(hostname string, remote net.Addr) []string {
}
return addresses
}
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -42,6 +43,14 @@ type SSHProxy struct {
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
browserOpener func(string) error
mu sync.RWMutex
backendClient *cryptossh.Client
// jwtToken is set once in runProxySSHServer before any handlers are called,
// so concurrent access is safe without additional synchronization.
jwtToken string
forwardedChannelsOnce sync.Once
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
@@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse
}
func (p *SSHProxy) Close() error {
p.mu.Lock()
backendClient := p.backendClient
p.backendClient = nil
p.mu.Unlock()
if backendClient != nil {
if err := backendClient.Close(); err != nil {
log.Debugf("close backend client: %v", err)
}
}
if p.conn != nil {
return p.conn.Close()
}
@@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort)
return p.runProxySSHServer(jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
p.jwtToken = jwtToken
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
Handler: p.handleSSHSession,
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
@@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
if !isPty && !hasCommand {
p.handleNonInteractiveSession(session, sshClient)
return
}
serverSession, err := sshClient.NewSession()
if err != nil {
@@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err)
@@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
}()
}
if len(session.Command()) > 0 {
if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
@@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
if err := session.Exit(exitErr.ExitStatus()); err != nil {
log.Debugf("set exit status: %v", err)
}
}
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
<-session.Context().Done()
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
@@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
// directTCPIPHandler handles local port forwarding (direct-tcpip channel).
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) {
var payload struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload")
return
}
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
log.Debugf("local port forwarding: %s", dest)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed")
return
}
backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err)
var openErr *cryptossh.OpenChannelError
if errors.As(err, &openErr) {
_ = newChan.Reject(openErr.Reason, openErr.Message)
} else {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
}
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := newChan.Accept()
if err != nil {
log.Debugf("local port forwarding: accept channel: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
// tcpipForwardHandler handles remote port forwarding (tcpip-forward request).
func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err)
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
if ok {
actualPort := reqPayload.Port
if reqPayload.Port == 0 && len(payload) >= 4 {
actualPort = binary.BigEndian.Uint32(payload)
}
log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort)
p.forwardedChannelsOnce.Do(func() {
go p.handleForwardedChannels(sshCtx, backendClient)
})
}
return ok, payload
}
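The port-0 branch above follows RFC 4254: when the client requests port 0, the server's tcpip-forward reply carries the actually bound port as a 4-byte big-endian integer. A self-contained stdlib sketch of that encoding (helper names are illustrative, not from this codebase):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// encodeForwardReply builds the tcpip-forward success payload: the port
// actually bound by the server, as big-endian uint32 (RFC 4254 section 7.1).
func encodeForwardReply(boundPort uint32) []byte {
	payload := make([]byte, 4)
	binary.BigEndian.PutUint32(payload, boundPort)
	return payload
}

// decodeForwardReply mirrors the proxy's handling: if the client requested
// port 0, the real port is read back from the reply payload.
func decodeForwardReply(requestedPort uint32, payload []byte) uint32 {
	if requestedPort == 0 && len(payload) >= 4 {
		return binary.BigEndian.Uint32(payload)
	}
	return requestedPort
}

func main() {
	reply := encodeForwardReply(42017)
	fmt.Println(decodeForwardReply(0, reply)) // server-allocated port: 42017
	fmt.Println(decodeForwardReply(8080, nil)) // fixed port, empty reply: 8080
}
```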
// cancelTcpipForwardHandler handles cancel-tcpip-forward request.
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient := p.getBackendClient()
if backendClient == nil {
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
return ok, payload
}
// getOrCreateBackendClient returns the existing backend client or creates a new one.
func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.backendClient != nil {
return p.backendClient, nil
}
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
log.Debugf("connecting to backend %s", targetAddr)
client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken)
if err != nil {
return nil, err
}
log.Debugf("backend connection established to %s", targetAddr)
p.backendClient = client
return client, nil
}
// getBackendClient returns the existing backend client or nil.
func (p *SSHProxy) getBackendClient() *cryptossh.Client {
p.mu.RLock()
defer p.mu.RUnlock()
return p.backendClient
}
// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding.
// When the backend receives incoming connections on the forwarded port, it sends them as
// "forwarded-tcpip" channels which we need to proxy to the client.
func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) {
sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if !ok || sshConn == nil {
log.Debugf("no SSH connection in context for forwarded channels")
return
}
channelChan := backendClient.HandleChannelOpen("forwarded-tcpip")
for {
select {
case <-sshCtx.Done():
return
case newChannel, ok := <-channelChan:
if !ok {
return
}
go p.handleForwardedChannel(sshCtx, sshConn, newChannel)
}
}
}
// handleForwardedChannel handles a single forwarded-tcpip channel from the backend.
func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) {
backendChan, backendReqs, err := newChannel.Accept()
if err != nil {
log.Debugf("remote port forwarding: accept from backend: %v", err)
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData())
if err != nil {
log.Debugf("remote port forwarding: open to client: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {


@@ -27,9 +27,11 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -137,6 +139,21 @@ func TestSSHProxy_Connect(t *testing.T) {
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
@@ -150,10 +167,10 @@ func TestSSHProxy_Connect(t *testing.T) {
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
@@ -347,12 +364,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}


@@ -23,10 +23,12 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestJWTEnforcement(t *testing.T) {
@@ -577,6 +579,22 @@ func TestJWTAuthentication(t *testing.T) {
tc.setupServer(server)
}
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
// Get current OS username for machine user mapping
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0}, // Allow test-user (index 0) to access current OS user
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())


@@ -1,25 +1,32 @@
// Package server implements port forwarding for the SSH server.
//
// Security note: Port forwarding runs in the main server process without privilege separation.
// The attack surface is primarily io.Copy through well-tested standard library code, making it
// lower risk than shell execution which uses privilege-separated child processes. We enforce
// user-level port restrictions: non-privileged users cannot bind to ports < 1024.
package server
import (
"encoding/binary"
"fmt"
"io"
"net"
"runtime"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
const privilegedPortThreshold = 1024
// sessionKey uniquely identifies an SSH session
type sessionKey string
// forwardKey uniquely identifies a port forwarding listener
type forwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
type tcpipForwardMsg struct {
@@ -47,34 +54,32 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
allowRemote := s.allowRemotePortForwarding
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", dstHost, dstPort)
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", dstHost, dstPort, err)
return false
}
return true
}
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
logger := s.getRequestLogger(ctx)
if !allowRemote {
logger.Warnf("remote port forwarding denied for %s:%d: disabled", bindHost, bindPort)
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
logger.Warnf("remote port forwarding denied for %s:%d: %v", bindHost, bindPort, err)
return false
}
return true
}
@@ -82,23 +87,20 @@ func (s *Server) configurePortForwarding(server *ssh.Server) {
}
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
// Returns nil if allowed, error if denied.
// For remote port forwarding (binding), it enforces that non-privileged users cannot bind to
// ports below 1024, mirroring the restriction they would face if binding directly.
//
// Note: FeatureSupportsUserSwitch is true because we accept requests from any authenticated user,
// though we don't actually switch users - port forwarding runs in the server process. The resolved
// user is used for privileged port access checks.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
if ctx == nil {
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
}
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: ctx.User(),
FeatureSupportsUserSwitch: true,
FeatureName: forwardType + " port forwarding",
})
@@ -106,12 +108,42 @@ func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType stri
return result.Error
}
if err := s.checkPrivilegedPortAccess(forwardType, port, result); err != nil {
return err
}
return nil
}
// checkPrivilegedPortAccess enforces that non-privileged users cannot bind to privileged ports.
// This applies to remote port forwarding where the server binds a port on behalf of the user.
// On Windows, there is no privileged port restriction, so this check is skipped.
func (s *Server) checkPrivilegedPortAccess(forwardType string, port uint32, result PrivilegeCheckResult) error {
if runtime.GOOS == "windows" {
return nil
}
isBindOperation := forwardType == "remote" || forwardType == "tcpip-forward"
if !isBindOperation {
return nil
}
// Port 0 means "pick any available port", which will be >= 1024
if port == 0 || port >= privilegedPortThreshold {
return nil
}
if result.User != nil && isPrivilegedUsername(result.User.Username) {
return nil
}
username := "unknown"
if result.User != nil {
username = result.User.Username
}
return fmt.Errorf("user %s cannot bind to privileged port %d (requires root)", username, port)
}
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
@@ -132,8 +164,6 @@ func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *crypto
return false, nil
}
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
sshConn, err := s.getSSHConnection(ctx)
if err != nil {
logger.Warnf("tcpip-forward request denied: %v", err)
@@ -153,8 +183,10 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
return false, nil
}
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
if s.removeRemoteForwardListener(key) {
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, payload.Port)
s.removeConnectionPortForward(ctx.RemoteAddr(), forwardAddr)
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
return true, nil
}
@@ -165,14 +197,11 @@ func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
logger := s.getRequestLogger(ctx)
defer func() {
if err := ln.Close(); err != nil {
logger.Debugf("remote forward listener close error for %s:%d: %v", host, port, err)
}
}()
@@ -196,28 +225,43 @@ func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, h
select {
case result := <-acceptChan:
if result.err != nil {
logger.Debugf("remote forward accept error: %v", result.err)
return
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
logger.Debugf("remote forward listener shutting down for %s:%d", host, port)
return
}
}
}
// getRequestLogger creates a logger with session/conn and jwt_user context
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
sessionKey := s.findSessionKeyByContext(ctx)
s.mu.RLock()
defer s.mu.RUnlock()
if state, exists := s.sessions[sessionKey]; exists {
logger := log.WithField("session", sessionKey)
if state.jwtUsername != "" {
logger = logger.WithField("jwt_user", state.jwtUsername)
}
return logger
}
if ctx.RemoteAddr() != nil {
if connState, exists := s.connections[connKey(ctx.RemoteAddr().String())]; exists {
return s.connLogger(connState)
}
}
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
return log.WithField("session", fmt.Sprintf("%s@%s", ctx.User(), remoteAddr))
}
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
@@ -227,6 +271,13 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding
}
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg
@@ -267,10 +318,11 @@ func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
key := forwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
s.storeRemoteForwardListener(key, ln)
forwardAddr := fmt.Sprintf("-R %s:%d", payload.Host, actualPort)
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
@@ -288,44 +340,34 @@ type acceptResult struct {
// handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
logger := s.getRequestLogger(ctx)
defer func() {
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}()
sshConn, ok := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if !ok || sshConn == nil {
logger.Debugf("remote forward: no SSH connection in context")
return
}
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
return
}
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr)
if err != nil {
logger.Debugf("open forward channel for %s:%d: %v", host, port, err)
return
}
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
}
// openForwardChannel creates an SSH forwarded-tcpip channel
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr) (cryptossh.Channel, error) {
payload := struct {
ConnectedAddress string
ConnectedPort uint32
@@ -346,41 +388,3 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
go cryptossh.DiscardRequests(reqs)
return channel, nil
}


@@ -9,6 +9,7 @@ import (
"io"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
@@ -39,6 +41,11 @@ const (
msgPrivilegedUserDisabled = "privileged user login is disabled"
cmdInteractiveShell = "<interactive shell>"
cmdPortForwarding = "<port forwarding>"
cmdSFTP = "<sftp>"
cmdNonInteractive = "<idle>"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
@@ -89,10 +96,10 @@ func logSessionExitError(logger *log.Entry, err error) {
}
}
// safeLogCommand returns a safe representation of the command for logging.
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return cmdInteractiveShell
}
if len(cmd) == 1 {
return cmd[0]
@@ -100,26 +107,50 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
// connState tracks the state of an SSH connection for port forwarding and status display.
type connState struct {
username string
remoteAddr net.Addr
portForwards []string
jwtUsername string
}
// authKey uniquely identifies an authentication attempt by username and remote address.
// Used to temporarily store JWT username between passwordHandler and sessionHandler.
type authKey string
// connKey uniquely identifies an SSH connection by its remote address.
// Used to track authenticated connections for status display and port forwarding.
type connKey string
func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
}
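The authKey/connKey split above keys one map by user@addr (the auth-to-session handoff) and the other by remote address alone (connection tracking). A toy sketch of the handoff between passwordHandler and sessionHandler, with hypothetical names and data:

```go
package main

import "fmt"

// authKey identifies an authentication attempt by username and remote address.
type authKey string

func newAuthKey(user, remoteAddr string) authKey {
	return authKey(fmt.Sprintf("%s@%s", user, remoteAddr))
}

func main() {
	// The password handler stores the JWT identity under user@addr...
	pending := map[authKey]string{}
	key := newAuthKey("alice", "100.64.0.7:52314")
	pending[key] = "alice@example.com"

	// ...and the session handler consumes it once the session channel opens.
	jwtUser, ok := pending[key]
	delete(pending, key)
	fmt.Println(jwtUser, ok) // alice@example.com true
}
```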
// sessionState tracks an active SSH session (shell, command, or subsystem like SFTP).
type sessionState struct {
session ssh.Session
sessionType string
jwtUsername string
}
type Server struct {
sshServer *ssh.Server
listener net.Listener
mu sync.RWMutex
hostKeyPEM []byte
// sessions tracks active SSH sessions (shell, command, SFTP).
// These are created when a client opens a session channel and requests shell/exec/subsystem.
sessions map[sessionKey]*sessionState
// pendingAuthJWT temporarily stores JWT username during the auth→session handoff.
// Populated in passwordHandler, consumed in sessionHandler/sftpSubsystemHandler.
pendingAuthJWT map[authKey]string
// connections tracks all SSH connections by their remote address.
// Populated at authentication time, stores JWT username and port forwards for status display.
connections map[connKey]*connState
allowLocalPortForwarding bool
allowRemotePortForwarding bool
@@ -131,13 +162,14 @@ type Server struct {
wgAddress wgaddr.Address
remoteForwardListeners map[forwardKey]net.Listener
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
authorizer *sshauth.Authorizer
suSupportsPty bool
loginIsUtilLinux bool
}
@@ -164,6 +196,7 @@ type SessionInfo struct {
RemoteAddress string
Command string
JWTUsername string
PortForwards []string
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
@@ -172,13 +205,13 @@ func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[forwardKey]net.Listener),
connections: make(map[connKey]*connState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
}
return s
@@ -207,6 +240,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return fmt.Errorf("create SSH server: %w", err)
}
s.listener = ln
s.sshServer = sshServer
log.Infof("SSH server started on %s", addrDesc)
@@ -259,16 +293,11 @@ func (s *Server) Stop() error {
}
s.sshServer = nil
s.listener = nil
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.connections)
for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil {
@@ -280,32 +309,82 @@ func (s *Server) Stop() error {
return nil
}
// Addr returns the address the SSH server is listening on, or nil if the server is not running
func (s *Server) Addr() net.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
if s.listener == nil {
return nil
}
return s.listener.Addr()
}
// GetStatus returns the current status of the SSH server and active sessions.
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()
defer s.mu.RUnlock()
enabled = s.sshServer != nil
reportedAddrs := make(map[string]bool)
for _, state := range s.sessions {
info := s.buildSessionInfo(state)
reportedAddrs[info.RemoteAddress] = true
sessions = append(sessions, info)
}
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
for key, connState := range s.connections {
remoteAddr := string(key)
if reportedAddrs[remoteAddr] {
continue
}
cmd := cmdNonInteractive
if len(connState.portForwards) > 0 {
cmd = cmdPortForwarding
}
sessions = append(sessions, SessionInfo{
Username: connState.username,
RemoteAddress: remoteAddr,
Command: cmd,
JWTUsername: connState.jwtUsername,
PortForwards: connState.portForwards,
})
}
return enabled, sessions
}
func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
session := state.session
cmd := state.sessionType
if cmd == "" {
cmd = safeLogCommand(session.Command())
}
remoteAddr := session.RemoteAddr().String()
info := SessionInfo{
Username: session.User(),
RemoteAddress: remoteAddr,
Command: cmd,
JWTUsername: state.jwtUsername,
}
connState, exists := s.connections[connKey(remoteAddr)]
if !exists {
return info
}
info.PortForwards = connState.portForwards
if len(connState.portForwards) > 0 && (cmd == cmdInteractiveShell || cmd == cmdNonInteractive) {
info.Command = cmdPortForwarding
}
return info
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
@@ -320,6 +399,19 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
// Reset JWT validator/extractor to pick up new userIDClaim
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
@@ -328,6 +420,7 @@ func (s *Server) ensureJWTValidator() error {
return nil
}
config := s.jwtConfig
authorizer := s.authorizer
s.mu.RUnlock()
if config == nil {
@@ -343,9 +436,16 @@ func (s *Server) ensureJWTValidator() error {
true,
)
// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
}
extractor := jwt.NewClaimsExtractor(extractorOptions...)
s.mu.Lock()
defer s.mu.Unlock()
@@ -493,59 +593,131 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr()
logger := s.getRequestLogger(ctx)
if err := s.ensureJWTValidator(); err != nil {
logger.Errorf("JWT validator initialization failed: %v", err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
logger.Warnf("JWT authentication failed: %v", err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
logger.Warnf("user validation failed: %v", err)
return false
}
logger = logger.WithField("jwt_user", userAuth.UserId)
s.mu.RLock()
authorizer := s.authorizer
s.mu.RUnlock()
msg, err := authorizer.Authorize(userAuth.UserId, osUsername)
if err != nil {
logger.Warnf("SSH auth denied: %v", err)
return false
}
logger.Infof("SSH auth %s", msg)
key := newAuthKey(osUsername, remoteAddr)
remoteAddrStr := ctx.RemoteAddr().String()
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.connections[connKey(remoteAddrStr)] = &connState{
username: ctx.User(),
remoteAddr: ctx.RemoteAddr(),
jwtUsername: userAuth.UserId,
}
s.mu.Unlock()
return true
}
func (s *Server) addConnectionPortForward(username string, remoteAddr net.Addr, forwardAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
key := connKey(remoteAddr.String())
if state, exists := s.connections[key]; exists {
if !slices.Contains(state.portForwards, forwardAddr) {
state.portForwards = append(state.portForwards, forwardAddr)
}
return
}
// Connection not in connections (non-JWT auth path)
s.connections[key] = &connState{
username: username,
remoteAddr: remoteAddr,
portForwards: []string{forwardAddr},
jwtUsername: s.pendingAuthJWT[newAuthKey(username, remoteAddr)],
}
}
func (s *Server) removeConnectionPortForward(remoteAddr net.Addr, forwardAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
state, exists := s.connections[connKey(remoteAddr.String())]
if !exists {
return
}
state.portForwards = slices.DeleteFunc(state.portForwards, func(addr string) bool {
return addr == forwardAddr
})
}
// trackedConn wraps a net.Conn to detect when it closes
type trackedConn struct {
net.Conn
server *Server
remoteAddr string
onceClose sync.Once
}
func (c *trackedConn) Close() error {
err := c.Conn.Close()
c.onceClose.Do(func() {
c.server.handleConnectionClose(c.remoteAddr)
})
return err
}
func (s *Server) handleConnectionClose(remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
key := connKey(remoteAddr)
state, exists := s.connections[key]
if exists && len(state.portForwards) > 0 {
s.connLogger(state).Info("port forwarding connection closed")
}
delete(s.connections, key)
}
func (s *Server) connLogger(state *connState) *log.Entry {
logger := log.WithField("session", fmt.Sprintf("%s@%s", state.username, state.remoteAddr))
if state.jwtUsername != "" {
logger = logger.WithField("jwt_user", state.jwtUsername)
}
return logger
}
func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
if ctx == nil {
return "unknown"
}
// Try to match by SSH connection
sshConn := ctx.Value(ssh.ContextKeyConn)
if sshConn == nil {
return "unknown"
@@ -554,19 +726,14 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
s.mu.RLock()
defer s.mu.RUnlock()
// Look through sessions to find one with matching connection
for sessionKey, state := range s.sessions {
if state.session.Context().Value(ssh.ContextKeyConn) == sshConn {
return sessionKey
}
}
// If no session found, this might be during early connection setup
if ctx.User() != "" && ctx.RemoteAddr() != nil {
return sessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
}
return "unknown"
@@ -607,7 +774,11 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
}
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
return &trackedConn{
Conn: conn,
server: s,
remoteAddr: conn.RemoteAddr().String(),
}
}
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
@@ -635,9 +806,8 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
"tcpip-forward": s.tcpipForwardHandler,
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
},
ConnCallback: s.connectionValidator,
Version: serverVersion,
}
if s.jwtEnabled {
@@ -653,13 +823,13 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil
}
func (s *Server) storeRemoteForwardListener(key forwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
func (s *Server) removeRemoteForwardListener(key forwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -677,6 +847,8 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
}
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
logger := s.getRequestLogger(ctx)
var payload struct {
Host string
Port uint32
@@ -686,7 +858,7 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
logger.Debugf("channel reject error: %v", err)
}
return
}
@@ -696,19 +868,20 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.mu.RUnlock()
if !allowLocal {
logger.Warnf("local port forwarding denied for %s:%d: disabled", payload.Host, payload.Port)
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
// Check privilege requirements for the destination port
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
logger.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
forwardAddr := fmt.Sprintf("-L %s:%d", payload.Host, payload.Port)
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}


@@ -224,6 +224,96 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
}
}
func TestServer_PrivilegedPortAccess(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
}
server := New(serverConfig)
server.SetAllowRemotePortForwarding(true)
tests := []struct {
name string
forwardType string
port uint32
username string
expectError bool
errorMsg string
skipOnWindows bool
}{
{
name: "non-root user remote forward privileged port",
forwardType: "remote",
port: 80,
username: "testuser",
expectError: true,
errorMsg: "cannot bind to privileged port",
skipOnWindows: true,
},
{
name: "non-root user tcpip-forward privileged port",
forwardType: "tcpip-forward",
port: 443,
username: "testuser",
expectError: true,
errorMsg: "cannot bind to privileged port",
skipOnWindows: true,
},
{
name: "non-root user remote forward unprivileged port",
forwardType: "remote",
port: 8080,
username: "testuser",
expectError: false,
},
{
name: "non-root user remote forward port 0",
forwardType: "remote",
port: 0,
username: "testuser",
expectError: false,
},
{
name: "root user remote forward privileged port",
forwardType: "remote",
port: 22,
username: "root",
expectError: false,
},
{
name: "local forward privileged port allowed for non-root",
forwardType: "local",
port: 80,
username: "testuser",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.skipOnWindows && runtime.GOOS == "windows" {
t.Skip("Windows does not have privileged port restrictions")
}
result := PrivilegeCheckResult{
Allowed: true,
User: &user.User{Username: tt.username},
}
err := server.checkPrivilegedPortAccess(tt.forwardType, tt.port, result)
if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestServer_PortConflictHandling(t *testing.T) {
// Test that multiple sessions requesting the same local port are handled naturally by the OS
// Get current user for SSH connection
@@ -392,3 +482,95 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
})
}
}
func TestServer_PortForwardingOnlySession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
expectAllowed bool
description string
}{
{
name: "session_allowed_with_local_forwarding",
allowLocalForwarding: true,
allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
},
{
name: "session_allowed_with_remote_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
},
{
name: "session_allowed_with_both",
allowLocalForwarding: true,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
},
{
name: "session_denied_without_forwarding",
allowLocalForwarding: false,
allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
serverAddr := StartTestServer(t, server)
defer func() {
_ = server.Stop()
}()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
client, err := sshclient.Dial(ctx, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
_ = client.Close()
}()
// Execute a command without PTY - this simulates ssh -T with no command
// The server should either allow it (port forwarding enabled) or reject it
output, err := client.ExecuteCommand(ctx, "")
if tt.expectAllowed {
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
})
}
}

Some files were not shown because too many files have changed in this diff.