Compare commits

..

8 Commits

Author SHA1 Message Date
bcmmbaga
72513d7522 Skip network map calculation when client serial matches current
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-01-27 23:03:43 +03:00
Zoltan Papp
a1f1bf1f19 Merge branch 'main' into feat/network-map-serial 2025-12-18 15:59:53 +01:00
Zoltan Papp
b5dec3df39 Track network serial in engine 2025-12-18 15:27:49 +01:00
Hakan Sariman
20f5f00635 [client] Add unit tests for engine synchronization and Info flag copying
- Introduced tests for the Engine's handleSync method to verify behavior when SkipNetworkMapUpdate is true and when NetworkMap is nil.
- Added a test for the Info struct to ensure correct copying of flag values from one instance to another, while preserving unrelated fields.
2025-10-17 10:03:07 +03:00
Hakan Sariman
fc141cf3a3 [client] Refactor lastNetworkMapSerial handling in GrpcClient
- Removed atomic operations for lastNetworkMapSerial and replaced them with mutex-based methods for thread-safe access.
2025-09-29 18:49:23 +07:00
Hakan Sariman
d0c65fa08e [client] Add skipNetworkMapUpdate field to SyncResponse for conditional updates 2025-09-29 18:28:14 +07:00
Hakan Sariman
f241bfa339 Refactor flag setting in Info struct to use CopyFlagsFrom method 2025-09-29 15:38:35 +07:00
Hakan Sariman
4b2cd97d5f [client] Enhance SyncRequest with NetworkMap serial tracking
- Added `networkMapSerial` field to `SyncRequest` for tracking the last known network map serial number.
- Updated `GrpcClient` to store and utilize the last network map serial during sync operations, optimizing synchronization processes.
- Improved handling of system info updates to ensure accurate metadata is sent with sync requests.
2025-09-25 19:28:35 +07:00
139 changed files with 1198 additions and 12672 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.0"
SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -19,100 +19,6 @@ concurrency:
cancel-in-progress: true
jobs:
release_freebsd_port:
name: "FreeBSD Port / Build & Test"
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff
run: |
if ls netbird-*.diff 1> /dev/null 2>&1; then
echo "diff_exists=true" >> $GITHUB_OUTPUT
else
echo "diff_exists=false" >> $GITHUB_OUTPUT
echo "No diff file generated (port may already be up to date)"
fi
- name: Extract version
if: steps.check_diff.outputs.diff_exists == 'true'
id: version
run: |
VERSION=$(ls netbird-*.diff | sed 's/netbird-\(.*\)\.diff/\1/')
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "15.0"
prepare: |
# Install required packages
pkg install -y git curl portlint go
# Install Go for building
GO_TARBALL="go1.24.10.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
# Clone ports tree (shallow, only what we need)
git clone --depth 1 --filter=blob:none https://git.FreeBSD.org/ports.git /usr/ports
cd /usr/ports
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin
# Find the diff file
echo "Finding diff file..."
DIFF_FILE=$(find $PWD -name "netbird-*.diff" -type f 2>/dev/null | head -1)
echo "Found: $DIFF_FILE"
if [[ -z "$DIFF_FILE" ]]; then
echo "ERROR: Could not find diff file"
find ~ -name "*.diff" -type f 2>/dev/null || true
exit 1
fi
# Apply the generated diff from /usr/ports (diff has a/security/netbird/... paths)
cd /usr/ports
patch -p1 -V none < "$DIFF_FILE"
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
with:
name: freebsd-port-files
path: |
./netbird-*-issue.txt
./netbird-*.diff
retention-days: 30
release:
runs-on: ubuntu-latest-m
env:

View File

@@ -113,7 +113,7 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
<p float="left" align="middle">
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
<img src="https://docs.netbird.io/docs-static/img/architecture/high-level-dia.png" width="700"/>
</p>
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.

View File

@@ -59,6 +59,7 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
@@ -67,16 +68,18 @@ type Client struct {
deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateFile string
connectClient *internal.ConnectClient
}
// NewClient instantiate a new Client
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{
cfgFile: platformFiles.ConfigurationFilePath(),
deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter,
@@ -84,20 +87,15 @@ func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAd
recorder: peer.NewRecorder(""),
ctxCancelLock: &sync.Mutex{},
networkChangeListener: networkChangeListener,
stateFile: platformFiles.StateFilePath(),
}
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgFile,
ConfigPath: c.cfgFile,
})
if err != nil {
return err
@@ -124,22 +122,16 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgFile,
ConfigPath: c.cfgFile,
})
if err != nil {
return err
@@ -157,8 +149,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -1,257 +0,0 @@
//go:build android
package android
import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
const (
// Android-specific config filename (different from desktop default.json)
defaultConfigFilename = "netbird.cfg"
// Subdirectory for non-default profiles (must match Java Preferences.java)
profilesSubdir = "profiles"
// Android uses a single user context per app (non-empty username required by ServiceManager)
androidUsername = "android"
)
// Profile represents a profile for gomobile
type Profile struct {
Name string
IsActive bool
}
// ProfileArray wraps profiles for gomobile compatibility
type ProfileArray struct {
items []*Profile
}
// Length returns the number of profiles
func (p *ProfileArray) Length() int {
return len(p.items)
}
// Get returns the profile at index i
func (p *ProfileArray) Get(i int) *Profile {
if i < 0 || i >= len(p.items) {
return nil
}
return p.items[i]
}
/*
/data/data/io.netbird.client/files/ ← configDir parameter
├── netbird.cfg ← Default profile config
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
└── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
// It wraps the internal profilemanager to provide Android-specific behavior
type ProfileManager struct {
configDir string
serviceMgr *profilemanager.ServiceManager
}
// NewProfileManager creates a new profile manager for Android
func NewProfileManager(configDir string) *ProfileManager {
// Set the default config path for Android (stored in root configDir, not profiles/)
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
// Set global paths for Android
profilemanager.DefaultConfigPathDir = configDir
profilemanager.DefaultConfigPath = defaultConfigPath
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
// Create ServiceManager with profiles/ subdirectory
// This avoids modifying the global ConfigDirOverride for profile listing
profilesDir := filepath.Join(configDir, profilesSubdir)
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
return &ProfileManager{
configDir: configDir,
serviceMgr: serviceMgr,
}
}
// ListProfiles returns all available profiles
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
// Convert internal profiles to Android Profile type
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
Name: p.Name,
IsActive: p.IsActive,
})
}
return &ProfileArray{items: profiles}, nil
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
config, err := profilemanager.ReadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to read profile config: %w", err)
}
// Clear authentication by removing private key and SSH key
config.PrivateKey = ""
config.SSHKey = ""
// Save config using internal profilemanager
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
activeProfile, err := pm.GetActiveProfile()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -115,24 +115,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
loginRequest.OptionalPreSharedKey = &preSharedKey
}
// set the new config
cfg, err := client.GetConfig(ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
Username: username,
})
if err != nil {
return fmt.Errorf("get config from daemon: %v", err)
}
req := setupSetConfigReqForLogin(cfg, activeProf.Name, username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon")
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
}
var loginErr error
var loginResp *proto.LoginResponse
@@ -416,34 +398,3 @@ func setEnvAndFlags(cmd *cobra.Command) error {
return nil
}
func setupSetConfigReqForLogin(cfg *proto.GetConfigResponse, profileName, username string) *proto.SetConfigRequest {
var req proto.SetConfigRequest
req.ProfileName = profileName
req.Username = username
req.ManagementUrl = managementURL
req.AdminURL = adminURL
req.RosenpassEnabled = &cfg.RosenpassEnabled
req.RosenpassPermissive = &cfg.RosenpassPermissive
req.DisableAutoConnect = &cfg.DisableAutoConnect
req.ServerSSHAllowed = &cfg.ServerSSHAllowed
req.NetworkMonitor = &cfg.NetworkMonitor
req.DisableClientRoutes = &cfg.DisableClientRoutes
req.DisableServerRoutes = &cfg.DisableServerRoutes
req.DisableDns = &cfg.DisableDns
req.DisableFirewall = &cfg.DisableFirewall
req.BlockLanAccess = &cfg.BlockLanAccess
req.DisableNotifications = &cfg.DisableNotifications
req.LazyConnectionEnabled = &cfg.LazyConnectionEnabled
req.BlockInbound = &cfg.BlockInbound
req.DisableSSHAuth = &cfg.DisableSSHAuth
req.EnableSSHRoot = &cfg.EnableSSHRoot
req.EnableSSHSFTP = &cfg.EnableSSHSFTP
req.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
req.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
req.SshJWTCacheTTL = &cfg.SshJWTCacheTTL
return &req
}

View File

@@ -85,9 +85,6 @@ var (
// Execute executes the root command.
func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
}
return rootCmd.Execute()
}

View File

@@ -1,176 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}

View File

@@ -1,276 +0,0 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

View File

@@ -1,21 +0,0 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}

View File

@@ -1,220 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}

View File

@@ -1,74 +0,0 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r, false)
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)

View File

@@ -1,13 +0,0 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}

View File

@@ -1,75 +0,0 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}

View File

@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available

View File

@@ -24,14 +24,10 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -43,13 +39,11 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -58,15 +52,13 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
@@ -170,33 +162,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
var path string
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
// On mobile, use the provided state file path directly
if !fileExists(mobileDependency.StateFilePath) {
if err := createFile(mobileDependency.StateFilePath); err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDependency.StateFilePath
} else {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -308,7 +273,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -318,15 +283,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)

View File

@@ -27,7 +27,6 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -57,7 +56,6 @@ block.prof: Block profiling information.
heap.prof: Heap profiling information (snapshot of memory allocations).
allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
Anonymization Process
@@ -111,9 +109,6 @@ go tool pprof -http=:8088 heap.prof
This will open a web browser tab with the profiling information.
Stack Trace
The stack_trace.txt file contains a complete snapshot of all goroutine stack traces at the time the debug bundle was created.
Routes
The routes.txt file contains detailed routing table information in a tabular format:
@@ -332,10 +327,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add profiles to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
if err := g.addSyncResponse(); err != nil {
return fmt.Errorf("add sync response: %w", err)
}
@@ -363,10 +354,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
log.Errorf("failed to add updater logs: %v", err)
}
return nil
}
@@ -535,18 +522,6 @@ func (g *BundleGenerator) addProf() (err error) {
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)
stackTrace := bytes.NewReader(buf[:n])
if err := g.addFileToZip(stackTrace, "stack_trace.txt"); err != nil {
return fmt.Errorf("add stack trace file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
@@ -655,29 +630,6 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
func (g *BundleGenerator) addUpdateLogs() error {
inst := installer.New()
logFiles := inst.LogFiles()
if len(logFiles) == 0 {
return nil
}
log.Infof("adding updater logs")
for _, logFile := range logFiles {
data, err := os.ReadFile(logFile)
if err != nil {
log.Warnf("failed to read update log file %s: %v", logFile, err)
continue
}
baseName := filepath.Base(logFile)
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
}
}
return nil
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync"
@@ -27,11 +26,6 @@ type Resolver struct {
mutex sync.RWMutex
}
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
@@ -105,9 +99,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := lookupIPWithExtraTimeout(ctx, d)
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
if err != nil {
return err
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
}
var aRecords, aaaaRecords []dns.RR
@@ -165,36 +159,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return nil
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
if mgmtURL == nil {

View File

@@ -80,7 +80,6 @@ type DefaultServer struct {
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
currentConfigHash uint64
handlerChain *HandlerChain
extraDomains map[domain.Domain]int
@@ -208,7 +207,6 @@ func newDefaultServer(
hostsDNSHolder: newHostsDNSHolder(),
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
}
// register with root zone, handler chain takes care of the routing
@@ -588,29 +586,8 @@ func (s *DefaultServer) applyHostConfig() {
log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
hash, err := hashstructure.Hash(config, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Warnf("unable to hash the host dns configuration, will apply config anyway: %s", err)
// Fall through to apply config anyway (fail-safe approach)
} else if s.currentConfigHash == hash {
log.Debugf("not applying host config as there are no changes")
return
}
log.Debugf("applying host config as there are changes")
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err)
return
}
// Only update hash if it was computed successfully and config was applied
if err == nil {
s.currentConfigHash = hash
}
s.registerFallback(config)

View File

@@ -1602,10 +1602,7 @@ func TestExtraDomains(t *testing.T) {
"other.example.com.",
"duplicate.example.com.",
},
// Expect 3 calls instead of 4 because when deregistering duplicate.example.com,
// the domain remains in the config (ref count goes from 2 to 1), so the host
// config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 3,
applyHostConfigCall: 4,
},
{
name: "Config update with new domains after registration",
@@ -1660,10 +1657,7 @@ func TestExtraDomains(t *testing.T) {
expectedMatchOnly: []string{
"extra.example.com.",
},
// Expect 2 calls instead of 3 because when deregistering protected.example.com,
// it's removed from extraDomains but still remains in the config (from customZones),
// so the host config hash doesn't change and applyDNSConfig is not called.
applyHostConfigCall: 2,
applyHostConfigCall: 3,
},
{
name: "Register domain that is part of nameserver group",

View File

@@ -42,13 +42,14 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -72,7 +73,6 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -201,9 +201,6 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -224,7 +221,17 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -240,12 +247,28 @@ func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signa
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -285,10 +308,6 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
@@ -522,13 +541,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -737,41 +749,6 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -781,10 +758,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -824,7 +797,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
nm := update.GetNetworkMap()
if nm == nil {
if nm == nil || update.SkipNetworkMapUpdate {
return nil
}
@@ -990,7 +963,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.networkSerial, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1151,8 +1124,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store

View File

@@ -11,18 +11,15 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
UpdateSSHAuth(config *sshauth.Config)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -356,38 +353,3 @@ func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.Sessio
return sshServer.GetStatus()
}
// updateSSHServerAuth updates SSH fine-grained access control configuration on a running SSH server
func (e *Engine) updateSSHServerAuth(sshAuth *mgmProto.SSHAuth) {
if sshAuth == nil {
return
}
if e.sshServer == nil {
return
}
protoUsers := sshAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid hash length %d, expected 16 - skipping SSH server auth update", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range sshAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
// Update SSH server with new authorization configuration
authConfig := &sshauth.Config{
UserIDClaim: sshAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
}
e.sshServer.UpdateSSHAuth(authConfig)
}

View File

@@ -0,0 +1,79 @@
package internal
import (
"context"
"testing"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Ensures handleSync exits early when SkipNetworkMapUpdate is true
func TestEngine_HandleSync_SkipNetworkMapUpdate(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun199",
WgAddr: "100.70.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
// Precondition
if engine.networkSerial != 0 {
t.Fatalf("unexpected initial serial: %d", engine.networkSerial)
}
resp := &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{Serial: 42},
SkipNetworkMapUpdate: true,
}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
if engine.networkSerial != 0 {
t.Fatalf("networkSerial changed despite SkipNetworkMapUpdate; got %d, want 0", engine.networkSerial)
}
}
// Ensures handleSync exits early when NetworkMap is nil
func TestEngine_HandleSync_NilNetworkMap(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, nil, &client.MockClient{}, nil, &EngineConfig{
WgIfaceName: "utun198",
WgAddr: "100.70.0.2/24",
WgPrivateKey: key,
WgPort: 33101,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
resp := &mgmtProto.SyncResponse{NetworkMap: nil}
if err := engine.handleSync(resp); err != nil {
t.Fatalf("handleSync returned error: %v", err)
}
}

View File

@@ -253,7 +253,6 @@ func TestEngine_SSH(t *testing.T) {
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -415,13 +414,21 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -624,7 +631,7 @@ func TestEngine_Sync(t *testing.T) {
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
syncFunc := func(ctx context.Context, info *system.Info, networkSerial uint64, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
@@ -640,7 +647,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -805,7 +812,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1007,7 +1014,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1533,7 +1540,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}

View File

@@ -20,7 +20,7 @@ type EndpointUpdater struct {
wgConfig WgConfig
initiator bool
// mu protects cancelFunc
// mu protects updateWireGuardPeer and cancelFunc
mu sync.Mutex
cancelFunc func()
updateWg sync.WaitGroup
@@ -86,9 +86,11 @@ func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.U
case <-ctx.Done():
return
case <-t.C:
e.mu.Lock()
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
e.mu.Unlock()
}
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
@@ -166,26 +165,19 @@ func getConfigDir() (string, error) {
if ConfigDirOverride != "" {
return ConfigDirOverride, nil
}
base, err := baseConfigDir()
configDir, err := os.UserConfigDir()
if err != nil {
return "", err
}
configDir := filepath.Join(base, "netbird")
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
}
func baseConfigDir() (string, error) {
if runtime.GOOS == "darwin" {
if u, err := user.Current(); err == nil && u.HomeDir != "" {
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
configDir = filepath.Join(configDir, "netbird")
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
}
}
return os.UserConfigDir()
return configDir, nil
}
func getConfigDirForUser(username string) (string, error) {

View File

@@ -76,7 +76,6 @@ func (a *ActiveProfileState) FilePath() (string, error) {
}
type ServiceManager struct {
profilesDir string // If set, overrides ConfigDirOverride for profile operations
}
func NewServiceManager(defaultConfigPath string) *ServiceManager {
@@ -86,17 +85,6 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
return &ServiceManager{}
}
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
// This allows setting the profiles directory without modifying the global ConfigDirOverride
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
if defaultConfigPath != "" {
DefaultConfigPath = defaultConfigPath
}
return &ServiceManager{
profilesDir: profilesDir,
}
}
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
@@ -252,7 +240,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
configDir, err := getConfigDirForUser(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -282,7 +270,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
configDir, err := getConfigDirForUser(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
@@ -314,7 +302,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
}
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := s.getConfigDir(username)
configDir, err := getConfigDirForUser(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
@@ -373,7 +361,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
configDir, err := s.getConfigDir(activeProf.Username)
configDir, err := getConfigDirForUser(activeProf.Username)
if err != nil {
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
return defaultStatePath
@@ -381,12 +369,3 @@ func (s *ServiceManager) GetStatePath() string {
return filepath.Join(configDir, activeProf.Name+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
func (s *ServiceManager) getConfigDir(username string) (string, error) {
if s.profilesDir != "" {
return s.profilesDir, nil
}
return getConfigDirForUser(username)
}

View File

@@ -1,35 +0,0 @@
// Package updatemanager provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
// # Overview
//
// The update manager operates as a background service that continuously monitors for
// available updates and automatically initiates the update process when conditions are met.
// It integrates with the installer package to perform the actual installation.
//
// # Update Flow
//
// The complete update process follows these steps:
//
// 1. Manager receives update directive via SetVersion() or detects new version
// 2. Manager validates update should proceed (version comparison, rate limiting)
// 3. Manager publishes "updating" event to status recorder
// 4. Manager persists UpdateState to track update attempt
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
// 6. Manager triggers installation via installer.RunInstallation()
// 7. Installer package handles the actual installation process
// 8. On next startup, CheckUpdateSuccess() verifies update completion
// 9. Manager publishes success/failure event to status recorder
// 10. Manager cleans up UpdateState
//
// # State Management
//
// Update state is persisted across restarts to track update attempts:
//
// - PreUpdateVersion: Version before update attempt
// - TargetVersion: Version attempting to update to
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updatemanager

View File

@@ -1,138 +0,0 @@
package downloader
import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
const (
userAgent = "NetBird agent installer/%s"
DefaultRetryDelay = 3 * time.Second
)
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
log.Debugf("starting download from %s", url)
out, err := os.Create(dstFile)
if err != nil {
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
}
defer func() {
if cerr := out.Close(); cerr != nil {
log.Warnf("error closing file %q: %v", dstFile, cerr)
}
}()
// First attempt
err = downloadToFileOnce(ctx, url, out)
if err == nil {
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
// If retryDelay is 0, don't retry
if retryDelay == 0 {
return err
}
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
// Sleep before retry
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
}
// Truncate file before retry
if err := out.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate file on retry: %w", err)
}
if _, err := out.Seek(0, 0); err != nil {
return fmt.Errorf("failed to seek to beginning of file: %w", err)
}
// Second attempt
if err := downloadToFileOnce(ctx, url, out); err != nil {
return fmt.Errorf("download failed after retry: %w", err)
}
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return data, nil
}
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("failed to write response body to file: %w", err)
}
return nil
}
func sleepWithContext(ctx context.Context, duration time.Duration) error {
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -1,199 +0,0 @@
package downloader
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
const (
retryDelay = 100 * time.Millisecond
)
func TestDownloadToFile_Success(t *testing.T) {
// Create a test server that responds successfully
content := "test file content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
}
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
content := "test file content after retry"
var attemptCount atomic.Int32
// Create a test server that fails on first attempt, succeeds on second
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := attemptCount.Add(1)
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should succeed after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
t.Fatalf("expected no error after retry, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
// Verify it took 2 attempts
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should fail after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
t.Fatal("expected error after retry, got nil")
}
// Verify it tried 2 times
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Create a context that will be cancelled during retry delay
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay (during the retry sleep)
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
// Download the file (should fail due to context cancellation during retry)
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
if err == nil {
t.Fatal("expected error due to context cancellation, got nil")
}
// Should have only made 1 attempt (cancelled during retry delay)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_InvalidURL(t *testing.T) {
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
if err == nil {
t.Fatal("expected error for invalid URL, got nil")
}
}
func TestDownloadToFile_InvalidDestination(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test"))
}))
defer server.Close()
// Use an invalid destination path
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
if err == nil {
t.Fatal("expected error for invalid destination, got nil")
}
}
func TestDownloadToFile_NoRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file with retryDelay = 0 (should not retry)
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
t.Fatal("expected error, got nil")
}
// Verify it only made 1 attempt (no retry)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}

View File

@@ -1,7 +0,0 @@
//go:build !windows
package installer
func UpdaterBinaryNameWithoutExtension() string {
return updaterBinary
}

View File

@@ -1,11 +0,0 @@
package installer
import (
"path/filepath"
"strings"
)
func UpdaterBinaryNameWithoutExtension() string {
ext := filepath.Ext(updaterBinary)
return strings.TrimSuffix(updaterBinary, ext)
}

View File

@@ -1,111 +0,0 @@
// Package installer provides functionality for managing NetBird application
// updates and installations across Windows, macOS. It handles
// the complete update lifecycle including artifact download, cryptographic verification,
// installation execution, process management, and result reporting.
//
// # Architecture
//
// The installer package uses a two-process architecture to enable self-updates:
//
// 1. Service Process: The main NetBird daemon process that initiates updates
// 2. Updater Process: A detached child process that performs the actual installation
//
// This separation is critical because:
// - The service binary cannot update itself while running
// - The installer (EXE/MSI/PKG) will terminate the service during installation
// - The updater process survives service termination and restarts it after installation
// - Results can be communicated back to the service after it restarts
//
// # Update Flow
//
// Service Process (RunInstallation):
//
// 1. Validates target version format (semver)
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
// 3. Downloads installer file from GitHub releases (if applicable)
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
// launching updater)
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
// 6. Launches updater process with detached mode:
// - --temp-dir: Temporary directory path
// - --service-dir: Service installation directory
// - --installer-file: Path to downloaded installer (if applicable)
// - --dry-run: Optional flag to test without actually installing
// 7. Service process continues running (will be terminated by installer later)
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
//
// Updater Process (Setup):
//
// 1. Receives parameters from service via command-line arguments
// 2. Runs installer with appropriate silent/quiet flags:
// - Windows EXE: installer.exe /S
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
// - macOS PKG: installer -pkg installer.pkg -target /
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
// 3. Installer terminates daemon and UI processes
// 4. Installer replaces binaries with new version
// 5. Updater waits for installer to complete
// 6. Updater restarts daemon:
// - Windows: netbird.exe service start
// - macOS/Linux: netbird service start
// 7. Updater restarts UI:
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
// - Linux: Not implemented (UI typically auto-starts)
// 8. Updater writes result.json with success/error status
// 9. Updater process exits
//
// # Result Communication
//
// The ResultHandler (result.go) manages communication between updater and service:
//
// Result Structure:
//
// type Result struct {
// Success bool // true if installation succeeded
// Error string // error message if Success is false
// ExecutedAt time.Time // when installation completed
// }
//
// Result files are automatically cleaned up after being read.
//
// # File Locations
//
// Temporary Directory (platform-specific):
//
// Windows:
// - Path: %ProgramData%\Netbird\tmp-install
// - Example: C:\ProgramData\Netbird\tmp-install
//
// macOS:
// - Path: /var/lib/netbird/tmp-install
// - Requires root permissions
//
// Files created during installation:
//
// tmp-install/
// installer.log
// updater[.exe] # Copy of service binary
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
// result.json # Installation result
// msi.log # MSI verbose log (Windows MSI only)
//
// # API Reference
//
// # Cleanup
//
// CleanUpInstallerFiles() removes temporary files after successful installation:
// - Downloaded installer files (*.exe, *.msi, *.pkg)
// - Updater binary copy
// - Does NOT remove result.json (cleaned by ResultHandler after read)
// - Does NOT remove msi.log (kept for debugging)
//
// # Dry-Run Mode
//
// Dry-run mode allows testing the update process without actually installing:
//
// Enable via environment variable:
//
// export NB_AUTO_UPDATE_DRY_RUN=true
// netbird service install-update 0.29.0
package installer

View File

@@ -1,50 +0,0 @@
//go:build !windows && !darwin
package installer
import (
"context"
"fmt"
)
const (
updaterBinary = "updater"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
func (u *Installer) TempDir() string {
return ""
}
func (c *Installer) LogFiles() []string {
return []string{}
}
func (u *Installer) CleanUpInstallerFiles() error {
return nil
}
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
return fmt.Errorf("unsupported platform")
}
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
return fmt.Errorf("unsupported platform")
}

View File

@@ -1,293 +0,0 @@
//go:build windows || darwin
package installer
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"github.com/hashicorp/go-multierror"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{
tempDir: defaultTempDir,
}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
// RunInstallation starts the updater process to run the installation
// This will run by the original service process
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
resultHandler := NewResultHandler(u.tempDir)
defer func() {
if err != nil {
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
log.Errorf("failed to write error result: %v", writeErr)
}
}
}()
if err := validateTargetVersion(targetVersion); err != nil {
return err
}
if err := u.mkTempDir(); err != nil {
return err
}
var installerFile string
// Download files only when not using any third-party store
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
log.Infof("download installer")
var err error
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
if err != nil {
log.Errorf("failed to download installer: %v", err)
return err
}
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
if err != nil {
log.Errorf("failed to create artifact verify: %v", err)
return err
}
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
log.Errorf("artifact verification error: %v", err)
return err
}
}
log.Infof("running installer")
updaterPath, err := u.copyUpdater()
if err != nil {
return err
}
// the directory where the service has been installed
workspace, err := getServiceDir()
if err != nil {
return err
}
args := []string{
"--temp-dir", u.tempDir,
"--service-dir", workspace,
}
if isDryRunEnabled() {
args = append(args, "--dry-run=true")
}
if installerFile != "" {
args = append(args, "--installer-file", installerFile)
}
updateCmd := exec.Command(updaterPath, args...)
log.Infof("starting updater process: %s", updateCmd.String())
// Configure the updater to run in a separate session/process group
// so it survives the parent daemon being stopped
setUpdaterProcAttr(updateCmd)
// Start the updater process asynchronously
if err := updateCmd.Start(); err != nil {
return err
}
pid := updateCmd.Process.Pid
log.Infof("updater started with PID %d", pid)
// Release the process so the OS can fully detach it
if err := updateCmd.Process.Release(); err != nil {
log.Warnf("failed to release updater process: %v", err)
}
return nil
}
// CleanUpInstallerFiles
// - the installer file (pkg, exe, msi)
// - the selfcopy updater.exe
func (u *Installer) CleanUpInstallerFiles() error {
// Check if tempDir exists
info, err := os.Stat(u.tempDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
var merr *multierror.Error
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
}
entries, err := os.ReadDir(u.tempDir)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
for _, ext := range binaryExtensions {
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
}
break
}
}
}
return merr.ErrorOrNil()
}
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
fileURL := urlWithVersionArch(installerType, targetVersion)
// Clean up temp directory on error
var success bool
defer func() {
if !success {
if err := os.RemoveAll(u.tempDir); err != nil {
log.Errorf("error cleaning up temporary directory: %v", err)
}
}
}()
fileName := path.Base(fileURL)
if fileName == "." || fileName == "/" || fileName == "" {
return "", fmt.Errorf("invalid file URL: %s", fileURL)
}
outputFilePath := filepath.Join(u.tempDir, fileName)
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
return "", err
}
success = true
return outputFilePath, nil
}
func (u *Installer) TempDir() string {
return u.tempDir
}
func (u *Installer) mkTempDir() error {
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
log.Debugf("failed to create tempdir: %s", u.tempDir)
return err
}
return nil
}
func (u *Installer) copyUpdater() (string, error) {
src, err := getServiceBinary()
if err != nil {
return "", fmt.Errorf("failed to get updater binary: %w", err)
}
dst := filepath.Join(u.tempDir, updaterBinary)
if err := copyFile(src, dst); err != nil {
return "", fmt.Errorf("failed to copy updater binary: %w", err)
}
if err := os.Chmod(dst, 0o755); err != nil {
return "", fmt.Errorf("failed to set permissions: %w", err)
}
return dst, nil
}
func validateTargetVersion(targetVersion string) error {
if targetVersion == "" {
return fmt.Errorf("target version cannot be empty")
}
_, err := goversion.NewVersion(targetVersion)
if err != nil {
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
}
return nil
}
func copyFile(src, dst string) error {
log.Infof("copying %s to %s", src, dst)
in, err := os.Open(src)
if err != nil {
return fmt.Errorf("open source: %w", err)
}
defer func() {
if err := in.Close(); err != nil {
log.Warnf("failed to close source file: %v", err)
}
}()
out, err := os.Create(dst)
if err != nil {
return fmt.Errorf("create destination: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Warnf("failed to close destination file: %v", err)
}
}()
if _, err := io.Copy(out, in); err != nil {
return fmt.Errorf("copy: %w", err)
}
return nil
}
func getServiceDir() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", err
}
return filepath.Dir(exePath), nil
}
func getServiceBinary() (string, error) {
return os.Executable()
}
func isDryRunEnabled() bool {
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
}

View File

@@ -1,11 +0,0 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -1,12 +0,0 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, msiLogFile),
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -1,238 +0,0 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const (
daemonName = "netbird"
updaterBinary = "updater"
uiBinary = "/Applications/NetBird.app"
defaultTempDir = "/var/lib/netbird/tmp-install"
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
var (
binaryExtensions = []string{"pkg"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
// skip service restart if dry-run mode is enabled
if dryRun {
return
}
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(); err != nil {
log.Errorf("failed to start UI: %v", err)
}
}()
if dryRun {
time.Sleep(7 * time.Second)
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
switch TypeOfInstaller(ctx) {
case TypePKG:
resultErr = u.installPkgFile(ctx, installerFile)
case TypeHomebrew:
resultErr = u.updateHomeBrew(ctx)
}
return resultErr
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser() error {
log.Infof("starting netbird-ui: %s", uiBinary)
// Get the current console user
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("failed to get console user: %w", err)
}
username := strings.TrimSpace(string(output))
if username == "" || username == "root" {
return fmt.Errorf("no active user session found")
}
log.Infof("starting UI for user: %s", username)
// Get user's UID
userInfo, err := user.Lookup(username)
if err != nil {
return fmt.Errorf("failed to lookup user %s: %w", username, err)
}
// Start the UI process as the console user using launchctl
// This ensures the app runs in the user's context with proper GUI access
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
log.Infof("launchCmd: %s", launchCmd.String())
// Set the user's home directory for proper macOS app behavior
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
if err := launchCmd.Start(); err != nil {
return fmt.Errorf("failed to start UI process: %w", err)
}
// Release the process so it can run independently
if err := launchCmd.Process.Release(); err != nil {
log.Warnf("failed to release UI process: %v", err)
}
log.Infof("netbird-ui started successfully for user %s", username)
return nil
}
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
log.Infof("installing pkg file: %s", path)
// Kill any existing UI processes before installation
// This ensures the postinstall script's "open $APP" will start the new version
u.killUI()
volume := "/"
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
if err := cmd.Start(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if err := cmd.Wait(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("pkg file installed successfully")
return nil
}
func (u *Installer) updateHomeBrew(ctx context.Context) error {
log.Infof("updating homebrew")
// Kill any existing UI processes before upgrade
// This ensures the new version will be started after upgrade
u.killUI()
// Homebrew must be run as a non-root user
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
// Check both Apple Silicon and Intel Mac paths
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath := "/opt/homebrew/bin/brew"
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
// Try Intel Mac path
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath = "/usr/local/bin/brew"
}
fileInfo, err := os.Stat(brewTapPath)
if err != nil {
return fmt.Errorf("error getting homebrew installation path info: %w", err)
}
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
}
// Get username from UID
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
if err != nil {
return fmt.Errorf("error looking up brew installer user: %w", err)
}
userName := brewUser.Username
// Get user HOME, required for brew to run correctly
// https://github.com/Homebrew/brew/issues/15833
homeDir := brewUser.HomeDir
// Check if netbird-ui is installed (must run as the brew user, not root)
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
uiInstalled := checkUICmd.Run() == nil
// Homebrew does not support installing specific versions
// Thus it will always update to latest and ignore targetVersion
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
if uiInstalled {
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
}
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
cmd.Env = append(os.Environ(), "HOME="+homeDir)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
}
log.Infof("homebrew updated successfully")
return nil
}
func (u *Installer) killUI() {
log.Infof("killing existing netbird-ui processes")
cmd := exec.Command("pkill", "-x", "netbird-ui")
if output, err := cmd.CombinedOutput(); err != nil {
// pkill returns exit code 1 if no processes matched, which is fine
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
} else {
log.Infof("netbird-ui processes killed")
}
}
func urlWithVersionArch(_ Type, version string) string {
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -1,213 +0,0 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
daemonName = "netbird.exe"
uiName = "netbird-ui.exe"
updaterBinary = "updater.exe"
msiLogFile = "msi.log"
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
var (
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
// for the cleanup
binaryExtensions = []string{"msi", "exe"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(daemonFolder); err != nil {
log.Errorf("failed to start UI: %v", err)
}
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
}()
if dryRun {
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
installerType, err := typeByFileExtension(installerFile)
if err != nil {
log.Debugf("%v", err)
resultErr = err
return
}
var cmd *exec.Cmd
switch installerType {
case TypeExe:
log.Infof("run exe installer: %s", installerFile)
cmd = exec.CommandContext(ctx, installerFile, "/S")
default:
installerDir := filepath.Dir(installerFile)
logPath := filepath.Join(installerDir, msiLogFile)
log.Infof("run msi installer: %s", installerFile)
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
}
cmd.Dir = filepath.Dir(installerFile)
if resultErr = cmd.Start(); resultErr != nil {
log.Errorf("error starting installer: %v", resultErr)
return
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if resultErr = cmd.Wait(); resultErr != nil {
log.Errorf("installer process finished with error: %v", resultErr)
return
}
return nil
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser(daemonFolder string) error {
uiPath := filepath.Join(daemonFolder, uiName)
log.Infof("starting netbird-ui: %s", uiPath)
// Get the active console session ID
sessionID := windows.WTSGetActiveConsoleSessionId()
if sessionID == 0xFFFFFFFF {
return fmt.Errorf("no active user session found")
}
// Get the user token for that session
var userToken windows.Token
err := windows.WTSQueryUserToken(sessionID, &userToken)
if err != nil {
return fmt.Errorf("failed to query user token: %w", err)
}
defer func() {
if err := userToken.Close(); err != nil {
log.Warnf("failed to close user token: %v", err)
}
}()
// Duplicate the token to a primary token
var primaryToken windows.Token
err = windows.DuplicateTokenEx(
userToken,
windows.MAXIMUM_ALLOWED,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return fmt.Errorf("failed to duplicate token: %w", err)
}
defer func() {
if err := primaryToken.Close(); err != nil {
log.Warnf("failed to close token: %v", err)
}
}()
// Prepare startup info
var si windows.StartupInfo
si.Cb = uint32(unsafe.Sizeof(si))
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
var pi windows.ProcessInformation
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
if err != nil {
return fmt.Errorf("failed to convert path to UTF16: %w", err)
}
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
err = windows.CreateProcessAsUser(
primaryToken,
nil,
cmdLine,
nil,
nil,
false,
creationFlags,
nil,
nil,
&si,
&pi,
)
if err != nil {
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
}
// Close handles
if err := windows.CloseHandle(pi.Process); err != nil {
log.Warnf("failed to close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Warnf("failed to close thread handle: %v", err)
}
log.Infof("netbird-ui started successfully in session %d", sessionID)
return nil
}
func urlWithVersionArch(it Type, version string) string {
var url string
if it == TypeExe {
url = exeDownloadURL
} else {
url = msiDownloadURL
}
url = strings.ReplaceAll(url, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -1,5 +0,0 @@
package installer
const (
LogFile = "installer.log"
)

View File

@@ -1,15 +0,0 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run in a new session,
// making it independent of the parent daemon process. This ensures the updater
// survives when the daemon is stopped during the pkg installation.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
}

View File

@@ -1,14 +0,0 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run detached from the parent,
// making it independent of the parent daemon process.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
}
}

View File

@@ -1,7 +0,0 @@
//go:build devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
)

View File

@@ -1,7 +0,0 @@
//go:build !devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
)

View File

@@ -1,230 +0,0 @@
package installer
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
resultFile = "result.json"
)
type Result struct {
Success bool
Error string
ExecutedAt time.Time
}
// ResultHandler handles reading and writing update results
type ResultHandler struct {
resultFile string
}
// NewResultHandler creates a new communicator with the given directory path
// The result file will be created as "result.json" in the specified directory
func NewResultHandler(installerDir string) *ResultHandler {
// Create it if it doesn't exist
// do not care if already exists
_ = os.MkdirAll(installerDir, 0o700)
rh := &ResultHandler{
resultFile: filepath.Join(installerDir, resultFile),
}
return rh
}
func (rh *ResultHandler) GetErrorResultReason() string {
result, err := rh.tryReadResult()
if err == nil && !result.Success {
return result.Error
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return ""
}
func (rh *ResultHandler) WriteSuccess() error {
result := Result{
Success: true,
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) WriteErr(errReason error) error {
result := Result{
Success: false,
Error: errReason.Error(),
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
log.Infof("start watching result: %s", rh.resultFile)
// Check if file already exists (updater finished before we started watching)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
dir := filepath.Dir(rh.resultFile)
if err := rh.waitForDirectory(ctx, dir); err != nil {
return Result{}, err
}
return rh.watchForResultFile(ctx, dir)
}
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
ticker := time.NewTicker(300 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return nil
}
}
}
}
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Error(err)
return Result{}, err
}
defer func() {
if err := watcher.Close(); err != nil {
log.Warnf("failed to close watcher: %v", err)
}
}()
if err := watcher.Add(dir); err != nil {
return Result{}, fmt.Errorf("failed to watch directory: %v", err)
}
// Check again after setting up watcher to avoid race condition
// (file could have been created between initial check and watcher setup)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
for {
select {
case <-ctx.Done():
return Result{}, ctx.Err()
case event, ok := <-watcher.Events:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
if result, done := rh.handleWatchEvent(event); done {
return result, nil
}
case err, ok := <-watcher.Errors:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
return Result{}, fmt.Errorf("watcher error: %w", err)
}
}
}
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
if event.Name != rh.resultFile {
return Result{}, false
}
if event.Has(fsnotify.Create) {
result, err := rh.tryReadResult()
if err != nil {
log.Debugf("error while reading result: %v", err)
return result, true
}
log.Infof("installer result: %v", result)
return result, true
}
return Result{}, false
}
// Write writes the update result to a file for the UI to read
func (rh *ResultHandler) write(result Result) error {
log.Infof("write out installer result to: %s", rh.resultFile)
// Ensure directory exists
dir := filepath.Dir(rh.resultFile)
if err := os.MkdirAll(dir, 0o755); err != nil {
log.Errorf("failed to create directory %s: %v", dir, err)
return err
}
data, err := json.Marshal(result)
if err != nil {
return err
}
// Write to a temporary file first, then rename for atomic operation
tmpPath := rh.resultFile + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
log.Errorf("failed to create temp file: %s", err)
return err
}
// Atomic rename
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
log.Warnf("Failed to remove temp result file: %v", err)
}
return err
}
return nil
}
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil && !os.IsNotExist(err) {
return err
}
log.Debugf("delete installer result file: %s", rh.resultFile)
return nil
}
// tryReadResult attempts to read and validate the result file
func (rh *ResultHandler) tryReadResult() (Result, error) {
data, err := os.ReadFile(rh.resultFile)
if err != nil {
return Result{}, err
}
var result Result
if err := json.Unmarshal(data, &result); err != nil {
return Result{}, fmt.Errorf("invalid result format: %w", err)
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return result, nil
}

View File

@@ -1,14 +0,0 @@
package installer
type Type struct {
name string
downloadable bool
}
func (t Type) String() string {
return t.name
}
func (t Type) Downloadable() bool {
return t.downloadable
}

View File

@@ -1,22 +0,0 @@
package installer
import (
"context"
"os/exec"
)
var (
TypeHomebrew = Type{name: "Homebrew", downloadable: false}
TypePKG = Type{name: "pkg", downloadable: true}
)
func TypeOfInstaller(ctx context.Context) Type {
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
_, err := cmd.Output()
if err != nil && cmd.ProcessState.ExitCode() == 1 {
// Not installed using pkg file, thus installed using Homebrew
return TypeHomebrew
}
return TypePKG
}

View File

@@ -1,51 +0,0 @@
package installer
import (
"context"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
const (
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
)
var (
TypeExe = Type{name: "EXE", downloadable: true}
TypeMSI = Type{name: "MSI", downloadable: true}
)
func TypeOfInstaller(_ context.Context) Type {
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
for _, path := range paths {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
continue
}
if err := k.Close(); err != nil {
log.Warnf("Error closing registry key: %v", err)
}
return TypeExe
}
log.Debug("No registry entry found for Netbird, assuming MSI installation")
return TypeMSI
}
func typeByFileExtension(filePath string) (Type, error) {
switch {
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
return TypeExe, nil
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
return TypeMSI, nil
default:
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
}
}

View File

@@ -1,374 +0,0 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}
func (u UpdateState) Name() string {
return "autoUpdate"
}
type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
currentVersion string
update UpdateInterface
wg sync.WaitGroup
cancel context.CancelFunc
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protect update and expectedVersion fields
updateMutex sync.Mutex
triggerUpdateFn func(context.Context, string) error
}
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Home Brew installation")
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager, nil
}
// CheckUpdateSuccess checks if the update was successful and send a notification.
// It works without to start the update manager.
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
reason := m.lastResultErrReason()
if reason != "" {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %s", reason),
nil,
)
}
updateState, err := m.loadAndDeleteUpdateState(ctx)
if err != nil {
if errors.Is(err, errNoUpdateState) {
return
}
log.Errorf("failed to load update state: %v", err)
return
}
log.Debugf("auto-update state loaded, %v", *updateState)
if updateState.TargetVersion == m.currentVersion {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
nil,
)
return
}
}
func (m *Manager) Start(ctx context.Context) {
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
m.update.SetDaemonVersion(m.currentVersion)
m.update.SetOnUpdateListener(func() {
select {
case m.updateChannel <- struct{}{}:
default:
}
})
go m.update.StartFetcher()
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
go m.updateLoop(ctx)
}
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if expectedVersion == "" {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
return
}
if expectedVersion == latestVersion {
m.updateToLatestVersion = true
m.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("error parsing version: %v", err)
return
}
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
return
}
m.expectedVersion = expectedSemVer
m.updateToLatestVersion = false
}
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) Stop() {
if m.cancel == nil {
return
}
m.cancel()
m.updateMutex.Lock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
m.updateMutex.Unlock()
m.wg.Wait()
}
func (m *Manager) onContextCancel() {
if m.cancel == nil {
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
m.onContextCancel()
return
case <-m.mgmUpdateChan:
case <-m.updateChannel:
log.Infof("fetched new version info")
}
m.handleUpdate(ctx)
}
}
func (m *Manager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion) {
return
}
m.lastTrigger = time.Now()
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": updateVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: updateVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
if err = m.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
}
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
// Returns nil if no state exists.
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
stateType := &UpdateState{}
m.stateManager.RegisterState(stateType)
if err := m.stateManager.LoadState(stateType); err != nil {
return nil, fmt.Errorf("load state: %w", err)
}
state := m.stateManager.GetState(stateType)
if state == nil {
return nil, errNoUpdateState
}
updateState, ok := state.(*UpdateState)
if !ok {
return nil, fmt.Errorf("failed to cast state to UpdateState")
}
if err := m.stateManager.DeleteState(updateState); err != nil {
return nil, fmt.Errorf("delete state: %w", err)
}
if err := m.stateManager.PersistState(ctx); err != nil {
return nil, fmt.Errorf("persist state: %w", err)
}
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil {
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
return false
}
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
return true
}
func (m *Manager) lastResultErrReason() string {
inst := installer.New()
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}

View File

@@ -1,214 +0,0 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

View File

@@ -1,39 +0,0 @@
//go:build !windows && !darwin
package updatemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager is a no-op stub for unsupported platforms
type Manager struct{}
// NewManager returns a no-op manager for unsupported platforms
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
return nil, fmt.Errorf("update manager is not supported on this platform")
}
// CheckUpdateSuccess is a no-op on unsupported platforms
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
// no-op
}
// Start is a no-op on unsupported platforms
func (m *Manager) Start(ctx context.Context) {
// no-op
}
// SetVersion is a no-op on unsupported platforms
func (m *Manager) SetVersion(expectedVersion string) {
// no-op
}
// Stop is a no-op on unsupported platforms
func (m *Manager) Stop() {
// no-op
}

View File

@@ -1,302 +0,0 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/blake2s"
)
const (
tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
tagArtifactPublic = "ARTIFACT PUBLIC KEY"
maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
)
// ArtifactHash wraps a hash.Hash and counts bytes written
type ArtifactHash struct {
hash.Hash
}
// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
func NewArtifactHash() *ArtifactHash {
h, err := blake2s.New256(nil)
if err != nil {
panic(err) // Should never happen with nil Key
}
return &ArtifactHash{Hash: h}
}
func (ah *ArtifactHash) Write(b []byte) (int, error) {
return ah.Hash.Write(b)
}
// ArtifactKey is a signing Key used to sign artifacts
type ArtifactKey struct {
PrivateKey
}
func (k ArtifactKey) String() string {
return fmt.Sprintf(
"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
// Verify root key is still valid
if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
}
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
ak := &ArtifactKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(ak.PrivateKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
// Sign the public key with the root key
signature, err := SignArtifactKey(*rootKey, pubPEM)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
}
return ak, privPEM, pubPEM, signature, nil
}
func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
if err != nil {
return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
}
return ArtifactKey{pk}, nil
}
func ParseArtifactPubKey(data []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(data, tagArtifactPublic)
return pk, err
}
func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
if len(keys) == 0 {
return nil, nil, errors.New("no keys to bundle")
}
// Create bundle by concatenating PEM-encoded keys
var pubBundle []byte
for _, pk := range keys {
// Marshal PublicKey struct to JSON
pubJSON, err := json.Marshal(pk)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
pubBundle = append(pubBundle, pubPEM...)
}
// Sign the entire bundle with the root key
signature, err := SignArtifactKey(*rootKey, pubBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
}
return pubBundle, signature, nil
}
func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
// Reconstruct the signed message: artifact_key_data || timestamp
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("failed to verify signature of artifact keys")
}
pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
if err != nil {
log.Debugf("failed to parse public keys: %s", err)
return nil, err
}
validKeys := make([]PublicKey, 0, len(pubKeys))
for _, pubKey := range pubKeys {
// Filter out expired keys
if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
log.Debugf("Key %s is expired at %v (current time %v)",
pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
continue
}
if revocationList != nil {
if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
log.Debugf("Key %s is revoked as of %v (created %v)",
pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
continue
}
}
validKeys = append(validKeys, pubKey)
}
if len(validKeys) == 0 {
log.Debugf("no valid public keys found for artifact keys")
return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys))
}
return validKeys, nil
}
func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
// Validate signature timestamp
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("failed to verify signature of artifact: %s", err)
return err
}
if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
return fmt.Errorf("artifact signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return fmt.Errorf("failed to hash artifact: %w", err)
}
hash := h.Sum(nil)
// Reconstruct the signed message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
// Find matching Key and verify
for _, keyInfo := range artifactPubKeys {
if keyInfo.Metadata.ID == signature.KeyID {
// Check Key expiration
if !keyInfo.Metadata.ExpiresAt.IsZero() &&
signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
return fmt.Errorf("signing Key %s expired at %v, signature from %v",
signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
}
if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
return nil
}
return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
}
}
return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
}
func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
if len(data) == 0 { // Check happens too late
return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return nil, fmt.Errorf("failed to write artifact hash: %w", err)
}
timestamp := time.Now().UTC()
if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
}
hash := h.Sum(nil)
// Create message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(artifactKey.Key, msg)
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: artifactKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
return json.Marshal(bundle)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +0,0 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -1,6 +0,0 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -1,174 +0,0 @@
// Package reposign implements a cryptographic signing and verification system
// for NetBird software update artifacts. It provides a hierarchical key
// management system with support for key rotation, revocation, and secure
// artifact distribution.
//
// # Architecture
//
// The package uses a two-tier key hierarchy:
//
// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
// in the client binary and establish the root of trust. Root keys should
// be kept offline and highly secured.
//
// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
// packages, etc.). These are rotated regularly and can be revoked if
// compromised. Artifact keys are signed by root keys and distributed via
// a public repository.
//
// This separation allows for operational flexibility: artifact keys can be
// rotated frequently without requiring client updates, while root keys remain
// stable and embedded in the software.
//
// # Cryptographic Primitives
//
// The package uses strong, modern cryptographic algorithms:
// - Ed25519: Fast, secure digital signatures (no timing attacks)
// - BLAKE2s-256: Fast cryptographic hash for artifacts
// - SHA-256: Key ID generation
// - JSON: Structured key and signature serialization
// - PEM: Standard key encoding format
//
// # Security Features
//
// Timestamp Binding:
// - All signatures include cryptographically-bound timestamps
// - Prevents replay attacks and enforces signature freshness
// - Clock skew tolerance: 5 minutes
//
// Key Expiration:
// - All keys have expiration times
// - Expired keys are automatically rejected
// - Signing with an expired key fails immediately
//
// Key Revocation:
// - Compromised keys can be revoked via a signed revocation list
// - Revocation list is checked during artifact validation
// - Revoked keys are filtered out before artifact verification
//
// # File Structure
//
// The package expects the following file layout in the key repository:
//
// signrepo/
// artifact-key-pub.pem # Bundle of artifact public keys
// artifact-key-pub.pem.sig # Root signature of the bundle
// revocation-list.json # List of revoked key IDs
// revocation-list.json.sig # Root signature of revocation list
//
// And in the artifacts repository:
//
// releases/
// v0.28.0/
// netbird-linux-amd64
// netbird-linux-amd64.sig # Artifact signature
// netbird-darwin-amd64
// netbird-darwin-amd64.sig
// ...
//
// # Embedded Root Keys
//
// Root public keys are embedded in the client binary at compile time:
// - Production keys: certs/ directory
// - Development keys: certsdev/ directory
//
// The build tag determines which keys are embedded:
// - Production builds: //go:build !devartifactsign
// - Development builds: //go:build devartifactsign
//
// This ensures that development artifacts cannot be verified using production
// keys and vice versa.
//
// # Key Rotation Strategies
//
// Root Key Rotation:
//
// Root keys can be rotated without breaking existing clients by leveraging
// the multi-key verification system. The loadEmbeddedPublicKeys function
// reads ALL files from the certs/ directory and accepts signatures from ANY
// of the embedded root keys.
//
// To rotate root keys:
//
// 1. Generate a new root key pair:
// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
//
// 2. Add the new public key to the certs/ directory as a new file:
// certs/
// root-pub-2024.pem # Old key (keep this!)
// root-pub-2025.pem # New key (add this)
//
// 3. Build new client versions with both keys embedded. The verification
// will accept signatures from either key.
//
// 4. Start signing new artifact keys with the new root key. Old clients
// with only the old root key will reject these, but new clients with
// both keys will accept them.
//
// Each file in certs/ can contain a single key or a bundle of keys (multiple
// PEM blocks). The system will parse all keys from all files and use them
// for verification. This provides maximum flexibility for key management.
//
// Important: Never remove all old root keys at once. Always maintain at least
// one overlapping key between releases to ensure smooth transitions.
//
// Artifact Key Rotation:
//
// Artifact keys should be rotated regularly (e.g., every 90 days) using the
// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
// keys to be bundled together in a single signed package, and ValidateArtifact
// will accept signatures from ANY key in the bundle.
//
// To rotate artifact keys smoothly:
//
// 1. Generate a new artifact key while keeping the old one:
// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
// // Keep oldPubPEM and oldKey available
//
// 2. Create a bundle containing both old and new public keys
//
// 3. Upload the bundle and its signature to the key repository:
// signrepo/artifact-key-pub.pem # Contains both keys
// signrepo/artifact-key-pub.pem.sig # Root signature
//
// 4. Start signing new releases with the NEW key, but keep the bundle
// unchanged. Clients will download the bundle (containing both keys)
// and accept signatures from either key.
//
// Key bundle validation workflow:
// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
// 3. ValidateArtifactKeys parses all public keys from the bundle
// 4. ValidateArtifactKeys filters out expired or revoked keys
// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
//
// This multi-key acceptance model enables overlapping validity periods and
// smooth transitions without client update requirements.
//
// # Best Practices
//
// Root Key Management:
// - Generate root keys offline on an air-gapped machine
// - Store root private keys in hardware security modules (HSM) if possible
// - Use separate root keys for production and development
// - Rotate root keys infrequently (e.g., every 5-10 years)
// - Plan for root key rotation: embed multiple root public keys
//
// Artifact Key Management:
// - Rotate artifact keys regularly (e.g., every 90 days)
// - Use separate artifact keys for different release channels if needed
// - Revoke keys immediately upon suspected compromise
// - Bundle multiple artifact keys to enable smooth rotation
//
// Signing Process:
// - Sign artifacts in a secure CI/CD environment
// - Never commit private keys to version control
// - Use environment variables or secret management for keys
// - Verify signatures immediately after signing
//
// Distribution:
// - Serve keys and revocation lists from a reliable CDN
// - Use HTTPS for all key and artifact downloads
// - Monitor download failures and signature verification failures
// - Keep revocation list up to date
package reposign

View File

@@ -1,10 +0,0 @@
//go:build devartifactsign
package reposign
import "embed"
//go:embed certsdev
var embeddedCerts embed.FS
const embeddedCertsDir = "certsdev"

View File

@@ -1,10 +0,0 @@
//go:build !devartifactsign
package reposign
import "embed"
//go:embed certs
var embeddedCerts embed.FS
const embeddedCertsDir = "certs"

View File

@@ -1,171 +0,0 @@
package reposign
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"time"
)
const (
maxClockSkew = 5 * time.Minute
)
// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key)
type KeyID [8]byte
// computeKeyID generates a unique ID from a public Key
func computeKeyID(pub ed25519.PublicKey) KeyID {
h := sha256.Sum256(pub)
var id KeyID
copy(id[:], h[:8])
return id
}
// MarshalJSON implements json.Marshaler for KeyID
func (k KeyID) MarshalJSON() ([]byte, error) {
return json.Marshal(k.String())
}
// UnmarshalJSON implements json.Unmarshaler for KeyID
func (k *KeyID) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
parsed, err := ParseKeyID(s)
if err != nil {
return err
}
*k = parsed
return nil
}
// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
func ParseKeyID(s string) (KeyID, error) {
var id KeyID
if len(s) != 16 {
return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
}
b, err := hex.DecodeString(s)
if err != nil {
return id, fmt.Errorf("failed to decode KeyID: %w", err)
}
copy(id[:], b)
return id, nil
}
func (k KeyID) String() string {
return fmt.Sprintf("%x", k[:])
}
// KeyMetadata contains versioning and lifecycle information for a Key
type KeyMetadata struct {
ID KeyID `json:"id"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
}
// PublicKey wraps a public Key with its Metadata
type PublicKey struct {
Key ed25519.PublicKey
Metadata KeyMetadata
}
func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
var keys []PublicKey
for len(bundle) > 0 {
keyInfo, rest, err := parsePublicKey(bundle, typeTag)
if err != nil {
return nil, err
}
keys = append(keys, keyInfo)
bundle = rest
}
if len(keys) == 0 {
return nil, errors.New("no keys found in bundle")
}
return keys, nil
}
func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
b, rest := pem.Decode(data)
if b == nil {
return PublicKey{}, nil, errors.New("failed to decode PEM data")
}
if b.Type != typeTag {
return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pub PublicKey
if err := json.Unmarshal(b.Bytes, &pub); err != nil {
return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
}
// Validate key length
if len(pub.Key) != ed25519.PublicKeySize {
return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
ed25519.PublicKeySize, len(pub.Key))
}
// Always recompute ID to ensure integrity
pub.Metadata.ID = computeKeyID(pub.Key)
return pub, rest, nil
}
type PrivateKey struct {
Key ed25519.PrivateKey
Metadata KeyMetadata
}
func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
b, rest := pem.Decode(data)
if b == nil {
return PrivateKey{}, errors.New("failed to decode PEM data")
}
if len(rest) > 0 {
return PrivateKey{}, errors.New("trailing PEM data")
}
if b.Type != typeTag {
return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pk PrivateKey
if err := json.Unmarshal(b.Bytes, &pk); err != nil {
return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
}
// Validate key length
if len(pk.Key) != ed25519.PrivateKeySize {
return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
ed25519.PrivateKeySize, len(pk.Key))
}
return pk, nil
}
func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
// Verify with root keys
var rootKeys []ed25519.PublicKey
for _, r := range publicRootKeys {
rootKeys = append(rootKeys, r.Key)
}
for _, k := range rootKeys {
if ed25519.Verify(k, msg, sig) {
return true
}
}
return false
}

View File

@@ -1,636 +0,0 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test KeyID functions
func TestComputeKeyID(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
// Verify it's the first 8 bytes of SHA-256
h := sha256.Sum256(pub)
expectedID := KeyID{}
copy(expectedID[:], h[:8])
assert.Equal(t, expectedID, keyID)
}
func TestComputeKeyID_Deterministic(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Computing KeyID multiple times should give the same result
keyID1 := computeKeyID(pub)
keyID2 := computeKeyID(pub)
assert.Equal(t, keyID1, keyID2)
}
func TestComputeKeyID_DifferentKeys(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
// Different keys should produce different IDs
assert.NotEqual(t, keyID1, keyID2)
}
func TestParseKeyID_Valid(t *testing.T) {
hexStr := "0123456789abcdef"
keyID, err := ParseKeyID(hexStr)
require.NoError(t, err)
expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
assert.Equal(t, expected, keyID)
}
func TestParseKeyID_InvalidLength(t *testing.T) {
tests := []struct {
name string
input string
}{
{"too short", "01234567"},
{"too long", "0123456789abcdef00"},
{"empty", ""},
{"odd length", "0123456789abcde"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseKeyID(tt.input)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid KeyID length")
})
}
}
func TestParseKeyID_InvalidHex(t *testing.T) {
invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
_, err := ParseKeyID(invalidHex)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode KeyID")
}
func TestKeyID_String(t *testing.T) {
keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
str := keyID.String()
assert.Equal(t, "0123456789abcdef", str)
}
func TestKeyID_RoundTrip(t *testing.T) {
original := "fedcba9876543210"
keyID, err := ParseKeyID(original)
require.NoError(t, err)
result := keyID.String()
assert.Equal(t, original, result)
}
func TestKeyID_ZeroValue(t *testing.T) {
keyID := KeyID{}
str := keyID.String()
assert.Equal(t, "0000000000000000", str)
}
// Test KeyMetadata
func TestKeyMetadata_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, metadata.ID, decoded.ID)
assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
}
func TestKeyMetadata_NoExpiration(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Time{}, // Zero value = no expiration
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.True(t, decoded.ExpiresAt.IsZero())
}
// Test PublicKey
func TestPublicKey_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
var decoded PublicKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, pubKey.Key, decoded.Key)
assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePublicKey
func TestParsePublicKey_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
}
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
// Marshal to JSON
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode to PEM
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse it back
parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
assert.Empty(t, rest)
assert.Equal(t, pub, parsed.Key)
assert.Equal(t, metadata.ID, parsed.Metadata.ID)
}
func TestParsePublicKey_InvalidPEM(t *testing.T) {
invalidPEM := []byte("not a PEM")
_, _, err := parsePublicKey(invalidPEM, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePublicKey_WrongType(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode with wrong type
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePublicKey_InvalidJSON(t *testing.T) {
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: []byte("invalid json"),
})
_, _, err := parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParsePublicKey_InvalidKeySize(t *testing.T) {
// Create a public key with wrong size
pubKey := PublicKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
}
func TestParsePublicKey_IDRecomputation(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Create a public key with WRONG ID
wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: wrongID,
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse should recompute the correct ID
parsed, _, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
correctID := computeKeyID(pub)
assert.Equal(t, correctID, parsed.Metadata.ID)
assert.NotEqual(t, wrongID, parsed.Metadata.ID)
}
// Test parsePublicKeyBundle
func TestParsePublicKeyBundle_Single(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Equal(t, pub, keys[0].Key)
}
func TestParsePublicKeyBundle_Multiple(t *testing.T) {
var bundle []byte
// Create 3 keys
for i := 0; i < 3; i++ {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
bundle = append(bundle, pemData...)
}
keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 3)
}
func TestParsePublicKeyBundle_Empty(t *testing.T) {
_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no keys found")
}
func TestParsePublicKeyBundle_Invalid(t *testing.T) {
_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
assert.Error(t, err)
}
// Test PrivateKey
func TestPrivateKey_JSONMarshaling(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
var decoded PrivateKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, privKey.Key, decoded.Key)
assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePrivateKey
func TestParsePrivateKey_Valid(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
parsed, err := parsePrivateKey(pemData, tagRootPrivate)
require.NoError(t, err)
assert.Equal(t, priv, parsed.Key)
}
func TestParsePrivateKey_InvalidPEM(t *testing.T) {
_, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePrivateKey_TrailingData(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
// Add trailing data
pemData = append(pemData, []byte("extra data")...)
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "trailing PEM data")
}
func TestParsePrivateKey_WrongType(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
privKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
// Test verifyAny
func TestVerifyAny_ValidSignature(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_InvalidSignature(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
invalidSignature := make([]byte, ed25519.SignatureSize)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, invalidSignature)
assert.False(t, result)
}
func TestVerifyAny_MultipleKeys(t *testing.T) {
// Create 3 key pairs
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub3, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
{Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
{Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_NoMatchingKey(t *testing.T) {
_, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
// Only include pub2, not pub1
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
}
result := verifyAny(rootKeys, message, signature)
assert.False(t, result)
}
func TestVerifyAny_EmptyKeys(t *testing.T) {
message := []byte("test message")
signature := make([]byte, ed25519.SignatureSize)
result := verifyAny([]PublicKey{}, message, signature)
assert.False(t, result)
}
func TestVerifyAny_TamperedMessage(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
}
// Verify with different message
tamperedMessage := []byte("different message")
result := verifyAny(rootKeys, tamperedMessage, signature)
assert.False(t, result)
}

View File

@@ -1,229 +0,0 @@
package reposign
import (
"crypto/ed25519"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour
defaultRevocationListExpiration = 365 * 24 * time.Hour
)
type RevocationList struct {
Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time
LastUpdated time.Time `json:"last_updated"` // When the list was last modified
ExpiresAt time.Time `json:"expires_at"` // When the list expires
}
func (rl RevocationList) MarshalJSON() ([]byte, error) {
// Convert map[KeyID]time.Time to map[string]time.Time
strMap := make(map[string]time.Time, len(rl.Revoked))
for k, v := range rl.Revoked {
strMap[k.String()] = v
}
return json.Marshal(map[string]interface{}{
"revoked": strMap,
"last_updated": rl.LastUpdated,
"expires_at": rl.ExpiresAt,
})
}
func (rl *RevocationList) UnmarshalJSON(data []byte) error {
var temp struct {
Revoked map[string]time.Time `json:"revoked"`
LastUpdated time.Time `json:"last_updated"`
ExpiresAt time.Time `json:"expires_at"`
Version int `json:"version"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Convert map[string]time.Time back to map[KeyID]time.Time
rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked))
for k, v := range temp.Revoked {
kid, err := ParseKeyID(k)
if err != nil {
return fmt.Errorf("failed to parse KeyID %q: %w", k, err)
}
rl.Revoked[kid] = v
}
rl.LastUpdated = temp.LastUpdated
rl.ExpiresAt = temp.ExpiresAt
return nil
}
func ParseRevocationList(data []byte) (*RevocationList, error) {
var rl RevocationList
if err := json.Unmarshal(data, &rl); err != nil {
return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err)
}
// Initialize the map if it's nil (in case of empty JSON object)
if rl.Revoked == nil {
rl.Revoked = make(map[KeyID]time.Time)
}
if rl.LastUpdated.IsZero() {
return nil, fmt.Errorf("revocation list missing last_updated timestamp")
}
if rl.ExpiresAt.IsZero() {
return nil, fmt.Errorf("revocation list missing expires_at timestamp")
}
return &rl, nil
}
func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) {
revoList, err := ParseRevocationList(data)
if err != nil {
log.Debugf("failed to parse revocation list: %s", err)
return nil, err
}
now := time.Now().UTC()
// Validate signature timestamp
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("revocation list signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxRevocationSignatureAge {
err := fmt.Errorf("revocation list signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("revocation list signature error: %v", err)
return nil, err
}
// Ensure LastUpdated is not in the future (with clock skew tolerance)
if revoList.LastUpdated.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated)
log.Errorf("rejecting future-dated revocation list: %v", err)
return nil, err
}
// Check if the revocation list has expired
if now.After(revoList.ExpiresAt) {
err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now)
log.Errorf("rejecting expired revocation list: %v", err)
return nil, err
}
// Ensure ExpiresAt is not in the future by more than the expected expiration window
// (allows some clock skew but prevents maliciously long expiration times)
if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) {
err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt)
log.Errorf("rejecting revocation list with invalid expiration: %v", err)
return nil, err
}
// Validate signature timestamp is close to LastUpdated
// (prevents signing old lists with new timestamps)
timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs()
if timeDiff > maxClockSkew {
err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)",
signature.Timestamp, revoList.LastUpdated, timeDiff)
log.Errorf("timestamp mismatch in revocation list: %v", err)
return nil, err
}
// Reconstruct the signed message: revocation_list_data || timestamp || version
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("revocation list verification failed")
}
return revoList, nil
}
func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) {
now := time.Now()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now.UTC(),
ExpiresAt: now.Add(expiration).UTC(),
}
signature, err := signRevocationList(privateRootKey, rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
}
rlData, err := json.Marshal(&rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
}
signData, err := json.Marshal(signature)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
}
return rlData, signData, nil
}
func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) {
now := time.Now().UTC()
rl.Revoked[kid] = now
rl.LastUpdated = now
rl.ExpiresAt = now.Add(expiration)
signature, err := signRevocationList(privateRootKey, rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
}
rlData, err := json.Marshal(&rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
}
signData, err := json.Marshal(signature)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
}
return rlData, signData, nil
}
func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) {
data, err := json.Marshal(rl)
if err != nil {
return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err)
}
timestamp := time.Now().UTC()
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(privateRootKey.Key, msg)
signature := &Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: privateRootKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
return signature, nil
}

View File

@@ -1,860 +0,0 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test RevocationList marshaling/unmarshaling
func TestRevocationList_MarshalJSON(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC)
rl := &RevocationList{
Revoked: map[KeyID]time.Time{
keyID: revokedTime,
},
LastUpdated: lastUpdated,
ExpiresAt: expiresAt,
}
jsonData, err := json.Marshal(rl)
require.NoError(t, err)
// Verify it can be unmarshaled back
var decoded map[string]interface{}
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Contains(t, decoded, "revoked")
assert.Contains(t, decoded, "last_updated")
assert.Contains(t, decoded, "expires_at")
}
func TestRevocationList_UnmarshalJSON(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
jsonData := map[string]interface{}{
"revoked": map[string]string{
keyID.String(): revokedTime.Format(time.RFC3339),
},
"last_updated": lastUpdated.Format(time.RFC3339),
}
jsonBytes, err := json.Marshal(jsonData)
require.NoError(t, err)
var rl RevocationList
err = json.Unmarshal(jsonBytes, &rl)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
assert.Contains(t, rl.Revoked, keyID)
assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix())
}
func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
original := &RevocationList{
Revoked: map[KeyID]time.Time{
keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC),
},
LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC),
}
// Marshal
jsonData, err := original.MarshalJSON()
require.NoError(t, err)
// Unmarshal
var decoded RevocationList
err = decoded.UnmarshalJSON(jsonData)
require.NoError(t, err)
// Verify
assert.Len(t, decoded.Revoked, 2)
assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix())
assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix())
assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix())
}
func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) {
jsonData := []byte(`{
"revoked": {
"invalid_key_id": "2024-01-15T10:30:00Z"
},
"last_updated": "2024-01-15T11:00:00Z"
}`)
var rl RevocationList
err := json.Unmarshal(jsonData, &rl)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse KeyID")
}
func TestRevocationList_EmptyRevoked(t *testing.T) {
rl := &RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC(),
}
jsonData, err := rl.MarshalJSON()
require.NoError(t, err)
var decoded RevocationList
err = decoded.UnmarshalJSON(jsonData)
require.NoError(t, err)
assert.Empty(t, decoded.Revoked)
assert.NotNil(t, decoded.Revoked)
}
// Test ParseRevocationList
func TestParseRevocationList_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
rl := RevocationList{
Revoked: map[KeyID]time.Time{
keyID: revokedTime,
},
LastUpdated: lastUpdated,
ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC),
}
jsonData, err := rl.MarshalJSON()
require.NoError(t, err)
parsed, err := ParseRevocationList(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed)
assert.Len(t, parsed.Revoked, 1)
assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix())
}
func TestParseRevocationList_InvalidJSON(t *testing.T) {
invalidJSON := []byte("not valid json")
_, err := ParseRevocationList(invalidJSON)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParseRevocationList_MissingLastUpdated(t *testing.T) {
jsonData := []byte(`{
"revoked": {}
}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing last_updated")
}
func TestParseRevocationList_EmptyObject(t *testing.T) {
jsonData := []byte(`{}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing last_updated")
}
func TestParseRevocationList_NilRevoked(t *testing.T) {
lastUpdated := time.Now().UTC()
expiresAt := lastUpdated.Add(90 * 24 * time.Hour)
jsonData := []byte(`{
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `",
"expires_at": "` + expiresAt.Format(time.RFC3339) + `"
}`)
parsed, err := ParseRevocationList(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed.Revoked)
assert.Empty(t, parsed.Revoked)
}
func TestParseRevocationList_MissingExpiresAt(t *testing.T) {
lastUpdated := time.Now().UTC()
jsonData := []byte(`{
"revoked": {},
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `"
}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing expires_at")
}
// Test ValidateRevocationList
func TestValidateRevocationList_Valid(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Validate
rl, err := ValidateRevocationList(rootKeys, rlData, *signature)
require.NoError(t, err)
assert.NotNil(t, rl)
assert.Empty(t, rl.Revoked)
}
func TestValidateRevocationList_InvalidSignature(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Create invalid signature
invalidSig := Signature{
Signature: make([]byte, 64),
Timestamp: time.Now().UTC(),
KeyID: computeKeyID(rootPub),
Algorithm: "ed25519",
HashAlgo: "sha512",
}
// Validate should fail
_, err = ValidateRevocationList(rootKeys, rlData, invalidSig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "verification failed")
}
func TestValidateRevocationList_FutureTimestamp(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Modify timestamp to be in the future
signature.Timestamp = time.Now().UTC().Add(10 * time.Minute)
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
assert.Error(t, err)
assert.Contains(t, err.Error(), "in the future")
}
func TestValidateRevocationList_TooOld(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Modify timestamp to be too old
signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour)
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too old")
}
func TestValidateRevocationList_InvalidJSON(t *testing.T) {
rootPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
signature := Signature{
Signature: make([]byte, 64),
Timestamp: time.Now().UTC(),
KeyID: computeKeyID(rootPub),
Algorithm: "ed25519",
HashAlgo: "sha512",
}
_, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature)
assert.Error(t, err)
}
func TestValidateRevocationList_FutureLastUpdated(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with future LastUpdated
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC().Add(10 * time.Minute),
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "LastUpdated is in the future")
}
func TestValidateRevocationList_TimestampMismatch(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with LastUpdated far in the past
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC().Add(-1 * time.Hour),
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it with current timestamp
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
// Modify signature timestamp to differ too much from LastUpdated
sig.Timestamp = time.Now().UTC()
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "differs too much")
}
func TestValidateRevocationList_Expired(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list that expired in the past
now := time.Now().UTC()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now.Add(-100 * 24 * time.Hour),
ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
// Adjust signature timestamp to match LastUpdated
sig.Timestamp = rl.LastUpdated
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}
func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge)
now := time.Now().UTC()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now,
ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too far in the future")
}
// Test CreateRevocationList
func TestCreateRevocationList_Valid(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
assert.NotEmpty(t, rlData)
assert.NotEmpty(t, sigData)
// Verify it can be parsed
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
assert.False(t, rl.LastUpdated.IsZero())
// Verify signature can be parsed
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
// Test ExtendRevocationList
func TestExtendRevocationList_AddKey(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
// Generate a key to revoke
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
revokedKeyID := computeKeyID(revokedPub)
// Extend the revocation list
newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
require.NoError(t, err)
// Verify the new list
newRL, err := ParseRevocationList(newRLData)
require.NoError(t, err)
assert.Len(t, newRL.Revoked, 1)
assert.Contains(t, newRL.Revoked, revokedKeyID)
// Verify signature
sig, err := ParseSignature(newSigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestExtendRevocationList_MultipleKeys(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
// Add first key
key1Pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
key1ID := computeKeyID(key1Pub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
// Add second key
key2Pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
key2ID := computeKeyID(key2Pub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 2)
assert.Contains(t, rl.Revoked, key1ID)
assert.Contains(t, rl.Revoked, key2ID)
}
func TestExtendRevocationList_DuplicateKey(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
// Add a key
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(keyPub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
firstRevocationTime := rl.Revoked[keyID]
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Add the same key again
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
// The revocation time should be updated
secondRevocationTime := rl.Revoked[keyID]
assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime))
}
func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
firstLastUpdated := rl.LastUpdated
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Extend list
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(keyPub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
// LastUpdated should be updated
assert.True(t, rl.LastUpdated.After(firstLastUpdated))
}
// Integration test
func TestRevocationList_FullWorkflow(t *testing.T) {
// Create root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Step 1: Create empty revocation list
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 2: Validate it
sig, err := ParseSignature(sigData)
require.NoError(t, err)
rl, err := ValidateRevocationList(rootKeys, rlData, *sig)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
// Step 3: Revoke a key
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
revokedKeyID := computeKeyID(revokedPub)
rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 4: Validate the extended list
sig, err = ParseSignature(sigData)
require.NoError(t, err)
rl, err = ValidateRevocationList(rootKeys, rlData, *sig)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
assert.Contains(t, rl.Revoked, revokedKeyID)
// Step 5: Verify the revocation time is reasonable
revTime := rl.Revoked[revokedKeyID]
now := time.Now().UTC()
assert.True(t, revTime.Before(now) || revTime.Equal(now))
assert.True(t, now.Sub(revTime) < time.Minute)
}

View File

@@ -1,120 +0,0 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"fmt"
"time"
)
const (
tagRootPrivate = "ROOT PRIVATE KEY"
tagRootPublic = "ROOT PUBLIC KEY"
)
// RootKey is a root Key used to sign signing keys
type RootKey struct {
PrivateKey
}
func (k RootKey) String() string {
return fmt.Sprintf(
"RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func ParseRootKey(privKeyPEM []byte) (*RootKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate)
if err != nil {
return nil, fmt.Errorf("failed to parse root Key: %w", err)
}
return &RootKey{pk}, nil
}
// ParseRootPublicKey parses a root public key from PEM format
func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err)
}
return pk, nil
}
// GenerateRootKey generates a new root Key pair with Metadata
func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) {
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, err
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
rk := &RootKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(rk.PrivateKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: pubJSON,
})
return rk, privPEM, pubPEM, nil
}
func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) {
timestamp := time.Now().UTC()
// This ensures the timestamp is cryptographically bound to the signature
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(rootKey.Key, msg)
// Create signature bundle with timestamp and Metadata
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: rootKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
return json.Marshal(bundle)
}

View File

@@ -1,476 +0,0 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test RootKey.String()
func TestRootKey_String(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC)
rk := RootKey{
PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: createdAt,
ExpiresAt: expiresAt,
},
},
}
str := rk.String()
assert.Contains(t, str, "RootKey")
assert.Contains(t, str, computeKeyID(pub).String())
assert.Contains(t, str, "2024-01-15")
assert.Contains(t, str, "2034-01-15")
}
func TestRootKey_String_NoExpiration(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
rk := RootKey{
PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: createdAt,
ExpiresAt: time.Time{}, // No expiration
},
},
}
str := rk.String()
assert.Contains(t, str, "RootKey")
assert.Contains(t, str, "0001-01-01") // Zero time format
}
// Test GenerateRootKey
func TestGenerateRootKey_Valid(t *testing.T) {
expiration := 10 * 365 * 24 * time.Hour // 10 years
rk, privPEM, pubPEM, err := GenerateRootKey(expiration)
require.NoError(t, err)
assert.NotNil(t, rk)
assert.NotEmpty(t, privPEM)
assert.NotEmpty(t, pubPEM)
// Verify the key has correct metadata
assert.False(t, rk.Metadata.CreatedAt.IsZero())
assert.False(t, rk.Metadata.ExpiresAt.IsZero())
assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt))
// Verify expiration is approximately correct
expectedExpiration := time.Now().Add(expiration)
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
}
func TestGenerateRootKey_ShortExpiration(t *testing.T) {
expiration := 24 * time.Hour // 1 day
rk, _, _, err := GenerateRootKey(expiration)
require.NoError(t, err)
assert.NotNil(t, rk)
// Verify expiration
expectedExpiration := time.Now().Add(expiration)
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
}
func TestGenerateRootKey_ZeroExpiration(t *testing.T) {
rk, _, _, err := GenerateRootKey(0)
require.NoError(t, err)
assert.NotNil(t, rk)
// With zero expiration, ExpiresAt should be equal to CreatedAt
assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt)
}
func TestGenerateRootKey_PEMFormat(t *testing.T) {
rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Verify private key PEM
privBlock, _ := pem.Decode(privPEM)
require.NotNil(t, privBlock)
assert.Equal(t, tagRootPrivate, privBlock.Type)
var privKey PrivateKey
err = json.Unmarshal(privBlock.Bytes, &privKey)
require.NoError(t, err)
assert.Equal(t, rk.Key, privKey.Key)
// Verify public key PEM
pubBlock, _ := pem.Decode(pubPEM)
require.NotNil(t, pubBlock)
assert.Equal(t, tagRootPublic, pubBlock.Type)
var pubKey PublicKey
err = json.Unmarshal(pubBlock.Bytes, &pubKey)
require.NoError(t, err)
assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID)
}
func TestGenerateRootKey_KeySize(t *testing.T) {
rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Ed25519 private key should be 64 bytes
assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key))
// Ed25519 public key should be 32 bytes
pubKey := rk.Key.Public().(ed25519.PublicKey)
assert.Equal(t, ed25519.PublicKeySize, len(pubKey))
}
func TestGenerateRootKey_UniqueKeys(t *testing.T) {
rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Different keys should have different IDs
assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID)
assert.NotEqual(t, rk1.Key, rk2.Key)
}
// Test ParseRootKey
func TestParseRootKey_Valid(t *testing.T) {
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
parsed, err := ParseRootKey(privPEM)
require.NoError(t, err)
assert.NotNil(t, parsed)
// Verify the parsed key matches the original
assert.Equal(t, original.Key, parsed.Key)
assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix())
assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix())
}
func TestParseRootKey_InvalidPEM(t *testing.T) {
_, err := ParseRootKey([]byte("not a valid PEM"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse")
}
func TestParseRootKey_EmptyData(t *testing.T) {
_, err := ParseRootKey([]byte{})
assert.Error(t, err)
}
func TestParseRootKey_WrongType(t *testing.T) {
// Generate an artifact key instead of root key
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
// Try to parse artifact key as root key
_, err = ParseRootKey(privPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
// Just to use artifactKey to avoid unused variable warning
_ = artifactKey
}
func TestParseRootKey_CorruptedJSON(t *testing.T) {
// Create PEM with corrupted JSON
corruptedPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: []byte("corrupted json data"),
})
_, err := ParseRootKey(corruptedPEM)
assert.Error(t, err)
}
func TestParseRootKey_InvalidKeySize(t *testing.T) {
// Create a key with invalid size
invalidKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
privJSON, err := json.Marshal(invalidKey)
require.NoError(t, err)
invalidPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON,
})
_, err = ParseRootKey(invalidPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
func TestParseRootKey_Roundtrip(t *testing.T) {
// Generate a key
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Parse it
parsed, err := ParseRootKey(privPEM)
require.NoError(t, err)
// Generate PEM again from parsed key
privJSON2, err := json.Marshal(parsed.PrivateKey)
require.NoError(t, err)
privPEM2 := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON2,
})
// Parse again
parsed2, err := ParseRootKey(privPEM2)
require.NoError(t, err)
// Should still match original
assert.Equal(t, original.Key, parsed2.Key)
assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID)
}
// Test SignArtifactKey
func TestSignArtifactKey_Valid(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data := []byte("test data to sign")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Parse and verify signature
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, rootKey.Metadata.ID, sig.KeyID)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "sha512", sig.HashAlgo)
assert.False(t, sig.Timestamp.IsZero())
}
func TestSignArtifactKey_EmptyData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
sigData, err := SignArtifactKey(*rootKey, []byte{})
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Should still be able to parse
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestSignArtifactKey_Verify(t *testing.T) {
rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Parse public key
pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
require.NoError(t, err)
// Sign some data
data := []byte("test data for verification")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
// Parse signature
sig, err := ParseSignature(sigData)
require.NoError(t, err)
// Reconstruct message
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
// Verify signature
valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
assert.True(t, valid)
}
func TestSignArtifactKey_DifferentData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data1 := []byte("data1")
data2 := []byte("data2")
sig1, err := SignArtifactKey(*rootKey, data1)
require.NoError(t, err)
sig2, err := SignArtifactKey(*rootKey, data2)
require.NoError(t, err)
// Different data should produce different signatures
assert.NotEqual(t, sig1, sig2)
}
func TestSignArtifactKey_MultipleSignatures(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data := []byte("test data")
// Sign twice with a small delay
sig1, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
sig2, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
// Signatures should be different due to different timestamps
assert.NotEqual(t, sig1, sig2)
// Parse both signatures
parsed1, err := ParseSignature(sig1)
require.NoError(t, err)
parsed2, err := ParseSignature(sig2)
require.NoError(t, err)
// Timestamps should be different
assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp))
}
func TestSignArtifactKey_LargeData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Create 1MB of data
largeData := make([]byte, 1024*1024)
for i := range largeData {
largeData[i] = byte(i % 256)
}
sigData, err := SignArtifactKey(*rootKey, largeData)
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Verify signature can be parsed
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestSignArtifactKey_TimestampInSignature(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
beforeSign := time.Now().UTC()
data := []byte("test data")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
afterSign := time.Now().UTC()
sig, err := ParseSignature(sigData)
require.NoError(t, err)
// Timestamp should be between before and after
assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second)))
assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second)))
}
// Integration test
func TestRootKey_FullWorkflow(t *testing.T) {
// Step 1: Generate root key
rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
assert.NotNil(t, rootKey)
assert.NotEmpty(t, privPEM)
assert.NotEmpty(t, pubPEM)
// Step 2: Parse the private key back
parsedRootKey, err := ParseRootKey(privPEM)
require.NoError(t, err)
assert.Equal(t, rootKey.Key, parsedRootKey.Key)
assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID)
// Step 3: Generate an artifact key using root key
artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
assert.NotNil(t, artifactKey)
// Step 4: Verify the artifact key signature
pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
require.NoError(t, err)
sig, err := ParseSignature(artifactSig)
require.NoError(t, err)
artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic)
require.NoError(t, err)
// Reconstruct message - SignArtifactKey signs the PEM, not the JSON
msg := make([]byte, 0, len(artifactPubPEM)+8)
msg = append(msg, artifactPubPEM...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
// Verify with root public key
valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
assert.True(t, valid, "Artifact key signature should be valid")
// Step 5: Use artifact key to sign data
testData := []byte("This is test artifact data")
dataSig, err := SignData(*artifactKey, testData)
require.NoError(t, err)
assert.NotEmpty(t, dataSig)
// Step 6: Verify the artifact data signature
dataSigParsed, err := ParseSignature(dataSig)
require.NoError(t, err)
err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed)
assert.NoError(t, err, "Artifact data signature should be valid")
}
func TestRootKey_ExpiredKeyWorkflow(t *testing.T) {
// Generate a root key that expires very soon
rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond)
require.NoError(t, err)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
// Try to generate artifact key with expired root key
_, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}

View File

@@ -1,24 +0,0 @@
package reposign
import (
"encoding/json"
"time"
)
// Signature contains a signature with associated Metadata
type Signature struct {
Signature []byte `json:"signature"`
Timestamp time.Time `json:"timestamp"`
KeyID KeyID `json:"key_id"`
Algorithm string `json:"algorithm"` // "ed25519"
HashAlgo string `json:"hash_algo"` // "blake2s" or sha512
}
func ParseSignature(data []byte) (*Signature, error) {
var signature Signature
if err := json.Unmarshal(data, &signature); err != nil {
return nil, err
}
return &signature, nil
}

View File

@@ -1,277 +0,0 @@
package reposign
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseSignature_Valid(t *testing.T) {
timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
keyID, err := ParseKeyID("0123456789abcdef")
require.NoError(t, err)
signatureData := []byte{0x01, 0x02, 0x03, 0x04}
jsonData, err := json.Marshal(Signature{
Signature: signatureData,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
})
require.NoError(t, err)
sig, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.Equal(t, signatureData, sig.Signature)
assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix())
assert.Equal(t, keyID, sig.KeyID)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "blake2s", sig.HashAlgo)
}
func TestParseSignature_InvalidJSON(t *testing.T) {
invalidJSON := []byte(`{invalid json}`)
sig, err := ParseSignature(invalidJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestParseSignature_EmptyData(t *testing.T) {
emptyJSON := []byte(`{}`)
sig, err := ParseSignature(emptyJSON)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.Empty(t, sig.Signature)
assert.True(t, sig.Timestamp.IsZero())
assert.Equal(t, KeyID{}, sig.KeyID)
assert.Empty(t, sig.Algorithm)
assert.Empty(t, sig.HashAlgo)
}
func TestParseSignature_MissingFields(t *testing.T) {
// JSON with only some fields
partialJSON := []byte(`{
"signature": "AQIDBA==",
"algorithm": "ed25519"
}`)
sig, err := ParseSignature(partialJSON)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.True(t, sig.Timestamp.IsZero())
}
func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) {
timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC)
keyID, err := ParseKeyID("fedcba9876543210")
require.NoError(t, err)
original := Signature{
Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
// Marshal
jsonData, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
// Verify
assert.Equal(t, original.Signature, parsed.Signature)
assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix())
assert.Equal(t, original.KeyID, parsed.KeyID)
assert.Equal(t, original.Algorithm, parsed.Algorithm)
assert.Equal(t, original.HashAlgo, parsed.HashAlgo)
}
func TestSignature_NilSignatureBytes(t *testing.T) {
timestamp := time.Now().UTC()
keyID, err := ParseKeyID("0011223344556677")
require.NoError(t, err)
sig := Signature{
Signature: nil,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Nil(t, parsed.Signature)
}
func TestSignature_LargeSignature(t *testing.T) {
timestamp := time.Now().UTC()
keyID, err := ParseKeyID("aabbccddeeff0011")
require.NoError(t, err)
// Create a large signature (64 bytes for ed25519)
largeSignature := make([]byte, 64)
for i := range largeSignature {
largeSignature[i] = byte(i)
}
sig := Signature{
Signature: largeSignature,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, largeSignature, parsed.Signature)
}
func TestSignature_WithDifferentHashAlgorithms(t *testing.T) {
tests := []struct {
name string
hashAlgo string
}{
{"blake2s", "blake2s"},
{"sha512", "sha512"},
{"sha256", "sha256"},
{"empty", ""},
}
keyID, err := ParseKeyID("1122334455667788")
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sig := Signature{
Signature: []byte{0x01, 0x02},
Timestamp: time.Now().UTC(),
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: tt.hashAlgo,
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, tt.hashAlgo, parsed.HashAlgo)
})
}
}
func TestSignature_TimestampPrecision(t *testing.T) {
// Test that timestamp preserves precision through JSON marshaling
timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC)
keyID, err := ParseKeyID("8877665544332211")
require.NoError(t, err)
sig := Signature{
Signature: []byte{0xaa, 0xbb},
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
// JSON timestamps typically have second or millisecond precision
// so we check that at least seconds match
assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix())
}
func TestParseSignature_MalformedKeyID(t *testing.T) {
// Test with a malformed KeyID field
malformedJSON := []byte(`{
"signature": "AQID",
"timestamp": "2024-01-15T10:30:00Z",
"key_id": "invalid_keyid_format",
"algorithm": "ed25519",
"hash_algo": "blake2s"
}`)
// This should fail since "invalid_keyid_format" is not a valid KeyID
sig, err := ParseSignature(malformedJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestParseSignature_InvalidTimestamp(t *testing.T) {
// Test with an invalid timestamp format
invalidTimestampJSON := []byte(`{
"signature": "AQID",
"timestamp": "not-a-timestamp",
"key_id": "0123456789abcdef",
"algorithm": "ed25519",
"hash_algo": "blake2s"
}`)
sig, err := ParseSignature(invalidTimestampJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestSignature_ZeroKeyID(t *testing.T) {
// Test with a zero KeyID
sig := Signature{
Signature: []byte{0x01, 0x02, 0x03},
Timestamp: time.Now().UTC(),
KeyID: KeyID{},
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, KeyID{}, parsed.KeyID)
}
func TestParseSignature_ExtraFields(t *testing.T) {
// JSON with extra fields that should be ignored
jsonWithExtra := []byte(`{
"signature": "AQIDBA==",
"timestamp": "2024-01-15T10:30:00Z",
"key_id": "0123456789abcdef",
"algorithm": "ed25519",
"hash_algo": "blake2s",
"extra_field": "should be ignored",
"another_extra": 12345
}`)
sig, err := ParseSignature(jsonWithExtra)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "blake2s", sig.HashAlgo)
}

View File

@@ -1,187 +0,0 @@
package reposign
import (
"context"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
)
const (
artifactPubKeysFileName = "artifact-key-pub.pem"
artifactPubKeysSigFileName = "artifact-key-pub.pem.sig"
revocationFileName = "revocation-list.json"
revocationSignFileName = "revocation-list.json.sig"
keySizeLimit = 5 * 1024 * 1024 //5MB
signatureLimit = 1024
revocationLimit = 10 * 1024 * 1024
)
type ArtifactVerify struct {
rootKeys []PublicKey
keysBaseURL *url.URL
revocationList *RevocationList
}
func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) {
allKeys, err := loadEmbeddedPublicKeys()
if err != nil {
return nil, err
}
return newArtifactVerify(keysBaseURL, allKeys)
}
func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) {
ku, err := url.Parse(keysBaseURL)
if err != nil {
return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err)
}
a := &ArtifactVerify{
rootKeys: allKeys,
keysBaseURL: ku,
}
return a, nil
}
func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error {
version = strings.TrimPrefix(version, "v")
revocationList, err := a.loadRevocationList(ctx)
if err != nil {
return fmt.Errorf("failed to load revocation list: %v", err)
}
a.revocationList = revocationList
artifactPubKeys, err := a.loadArtifactKeys(ctx)
if err != nil {
return fmt.Errorf("failed to load artifact keys: %v", err)
}
signature, err := a.loadArtifactSignature(ctx, version, artifactFile)
if err != nil {
return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
log.Errorf("failed to read artifact file: %v", err)
return fmt.Errorf("failed to read artifact file: %w", err)
}
if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil {
return fmt.Errorf("failed to validate artifact: %v", err)
}
return nil
}
func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) {
downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String()
data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit)
if err != nil {
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
return nil, err
}
downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String()
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
return nil, err
}
signature, err := ParseSignature(sigData)
if err != nil {
log.Debugf("failed to parse revocation list signature: %s", err)
return nil, err
}
return ValidateRevocationList(a.rootKeys, data, *signature)
}
func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) {
downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String()
log.Debugf("starting downloading artifact keys from: %s", downloadURL)
data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit)
if err != nil {
log.Debugf("failed to download artifact keys: %s", err)
return nil, err
}
downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String()
log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL)
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download signature of public keys: %s", err)
return nil, err
}
signature, err := ParseSignature(sigData)
if err != nil {
log.Debugf("failed to parse signature of public keys: %s", err)
return nil, err
}
return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList)
}
func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) {
artifactFile = filepath.Base(artifactFile)
downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String()
data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download artifact signature: %s", err)
return nil, err
}
signature, err := ParseSignature(data)
if err != nil {
log.Debugf("failed to parse artifact signature: %s", err)
return nil, err
}
return signature, nil
}
func loadEmbeddedPublicKeys() ([]PublicKey, error) {
files, err := embeddedCerts.ReadDir(embeddedCertsDir)
if err != nil {
return nil, fmt.Errorf("failed to read embedded certs: %w", err)
}
var allKeys []PublicKey
for _, file := range files {
if file.IsDir() {
continue
}
data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name())
if err != nil {
return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err)
}
keys, err := parsePublicKeyBundle(data, tagRootPublic)
if err != nil {
return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err)
}
allKeys = append(allKeys, keys...)
}
if len(allKeys) == 0 {
return nil, fmt.Errorf("no valid public keys found in embedded certs")
}
return allKeys, nil
}

View File

@@ -1,528 +0,0 @@
package reposign
import (
"context"
"crypto/ed25519"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test ArtifactVerify construction
func TestArtifactVerify_Construction(t *testing.T) {
// Generate test root key
rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic)
require.NoError(t, err)
keysBaseURL := "http://localhost:8080/artifact-signatures"
av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey})
require.NoError(t, err)
assert.NotNil(t, av)
assert.NotEmpty(t, av.rootKeys)
assert.Equal(t, keysBaseURL, av.keysBaseURL.String())
// Verify root key structure
assert.NotEmpty(t, av.rootKeys[0].Key)
assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID)
assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero())
}
func TestArtifactVerify_MultipleRootKeys(t *testing.T) {
// Generate multiple test root keys
rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic)
require.NoError(t, err)
rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic)
require.NoError(t, err)
keysBaseURL := "http://localhost:8080/artifact-signatures"
av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2})
assert.NoError(t, err)
assert.Len(t, av.rootKeys, 2)
assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID)
}
// Test Verify workflow with mock HTTP server
func TestArtifactVerify_FullWorkflow(t *testing.T) {
// Create temporary test directory
tempDir := t.TempDir()
// Step 1: Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Step 2: Generate artifact key
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
// Step 3: Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 4: Bundle artifact keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Step 5: Create test artifact
artifactPath := filepath.Join(tempDir, "test-artifact.bin")
artifactData := []byte("This is test artifact data for verification")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
// Step 6: Sign artifact
artifactSigData, err := SignData(*artifactKey, artifactData)
require.NoError(t, err)
// Step 7: Setup mock HTTP server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifacts/v1.0.0/test-artifact.bin":
_, _ = w.Write(artifactData)
case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
// Step 8: Create ArtifactVerify with test root key
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
// Step 9: Verify artifact
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.NoError(t, err)
}
func TestArtifactVerify_InvalidRevocationList(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
// Setup server with invalid revocation list
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write([]byte("invalid data"))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to load revocation list")
}
func TestArtifactVerify_MissingArtifactFile(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
// Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Create signature for non-existent file
testData := []byte("test")
artifactSigData, err := SignData(*artifactKey, testData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/missing.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", "file.bin")
assert.Error(t, err)
}
func TestArtifactVerify_ServerUnavailable(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
// Use URL that doesn't exist
av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
}
func TestArtifactVerify_ContextCancellation(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
// Create a server that delays response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(500 * time.Millisecond)
_, _ = w.Write([]byte("data"))
}))
defer server.Close()
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey})
require.NoError(t, err)
// Create context that cancels quickly
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
}
func TestArtifactVerify_WithRevocation(t *testing.T) {
tempDir := t.TempDir()
// Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Generate two artifact keys
artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
require.NoError(t, err)
_, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
require.NoError(t, err)
// Create revocation list with first key revoked
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
parsedRevocation, err := ParseRevocationList(emptyRevocation)
require.NoError(t, err)
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle both keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
require.NoError(t, err)
// Create artifact signed by revoked key
artifactPath := filepath.Join(tempDir, "test.bin")
artifactData := []byte("test data")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
artifactSigData, err := SignData(*artifactKey1, artifactData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should fail because the signing key is revoked
assert.Error(t, err)
assert.Contains(t, err.Error(), "no signing Key found")
}
func TestArtifactVerify_ValidWithSecondKey(t *testing.T) {
tempDir := t.TempDir()
// Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Generate two artifact keys
_, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
require.NoError(t, err)
artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
require.NoError(t, err)
// Create revocation list with first key revoked
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
parsedRevocation, err := ParseRevocationList(emptyRevocation)
require.NoError(t, err)
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle both keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
require.NoError(t, err)
// Create artifact signed by second key (not revoked)
artifactPath := filepath.Join(tempDir, "test.bin")
artifactData := []byte("test data")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
artifactSigData, err := SignData(*artifactKey2, artifactData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should succeed because second key is not revoked
assert.NoError(t, err)
}
func TestArtifactVerify_TamperedArtifact(t *testing.T) {
tempDir := t.TempDir()
// Generate root key and artifact key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
// Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Sign original data
originalData := []byte("original data")
artifactSigData, err := SignData(*artifactKey, originalData)
require.NoError(t, err)
// Write tampered data to file
artifactPath := filepath.Join(tempDir, "test.bin")
tamperedData := []byte("tampered data")
err = os.WriteFile(artifactPath, tamperedData, 0644)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should fail because artifact was tampered
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to validate artifact")
}
// Test URL validation
func TestArtifactVerify_URLParsing(t *testing.T) {
tests := []struct {
name string
keysBaseURL string
expectError bool
}{
{
name: "Valid HTTP URL",
keysBaseURL: "http://example.com/artifact-signatures",
expectError: false,
},
{
name: "Valid HTTPS URL",
keysBaseURL: "https://example.com/artifact-signatures",
expectError: false,
},
{
name: "URL with port",
keysBaseURL: "http://localhost:8080/artifact-signatures",
expectError: false,
},
{
name: "Invalid URL",
keysBaseURL: "://invalid",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := newArtifactVerify(tt.keysBaseURL, nil)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -1,11 +0,0 @@
package updatemanager
import v "github.com/hashicorp/go-version"
type UpdateInterface interface {
StopWatch()
SetDaemonVersion(newVersion string) bool
SetOnUpdateListener(updateFn func())
LatestVersion() *v.Version
StartFetcher()
}

View File

@@ -131,7 +131,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}

View File

@@ -51,7 +51,7 @@
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v6.33.1
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -893,7 +893,6 @@ type UpRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -942,13 +941,6 @@ func (x *UpRequest) GetUsername() string {
return ""
}
func (x *UpRequest) GetAutoUpdate() bool {
if x != nil && x.AutoUpdate != nil {
return *x.AutoUpdate
}
return false
}
type UpResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -1265,7 +1257,6 @@ type GetConfigResponse struct {
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
DisableFirewall bool `protobuf:"varint,27,opt,name=disable_firewall,json=disableFirewall,proto3" json:"disable_firewall,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1482,13 +1473,6 @@ func (x *GetConfigResponse) GetSshJWTCacheTTL() int32 {
return 0
}
func (x *GetConfigResponse) GetDisableFirewall() bool {
if x != nil {
return x.DisableFirewall
}
return false
}
// PeerState contains the latest state of a peer
type PeerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -5372,94 +5356,6 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
type InstallerResultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InstallerResultRequest) Reset() {
*x = InstallerResultRequest{}
mi := &file_daemon_proto_msgTypes[79]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InstallerResultRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InstallerResultRequest) ProtoMessage() {}
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
}
type InstallerResultResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InstallerResultResponse) Reset() {
*x = InstallerResultResponse{}
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InstallerResultResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InstallerResultResponse) ProtoMessage() {}
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
}
func (x *InstallerResultResponse) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *InstallerResultResponse) GetErrorMsg() string {
if x != nil {
return x.ErrorMsg
}
return ""
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5470,7 +5366,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5482,7 +5378,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[82]
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5606,16 +5502,12 @@ const file_daemon_proto_rawDesc = "" +
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
"\x14WaitSSOLoginResponse\x12\x14\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"p\n" +
"\tUpRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" +
"\n" +
"autoUpdate\x18\x03 \x01(\bH\x02R\n" +
"autoUpdate\x88\x01\x01B\x0e\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_usernameB\r\n" +
"\v_autoUpdate\"\f\n" +
"\t_username\"\f\n" +
"\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
@@ -5633,7 +5525,7 @@ const file_daemon_proto_rawDesc = "" +
"\fDownResponse\"P\n" +
"\x10GetConfigRequest\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername\"\x86\t\n" +
"\busername\x18\x02 \x01(\tR\busername\"\xdb\b\n" +
"\x11GetConfigResponse\x12$\n" +
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
"\n" +
@@ -5664,8 +5556,7 @@ const file_daemon_proto_rawDesc = "" +
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" +
"\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12)\n" +
"\x10disable_firewall\x18\x1b \x01(\bR\x0fdisableFirewall\"\xfe\x05\n" +
"\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\"\xfe\x05\n" +
"\tPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
@@ -6002,11 +5893,7 @@ const file_daemon_proto_rawDesc = "" +
"\x14WaitJWTTokenResponse\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -6015,7 +5902,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xb4\x13\n" +
"\x05TRACE\x10\a2\xdb\x12\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -6051,8 +5938,7 @@ const file_daemon_proto_rawDesc = "" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -6067,7 +5953,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6152,21 +6038,19 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
nil, // 85: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
nil, // 87: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
nil, // 83: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range
nil, // 85: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 86: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6177,8 +6061,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6189,10 +6073,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6227,42 +6111,40 @@ var file_daemon_proto_depIdxs = []int32{
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
67, // [67:100] is the sub-list for method output_type
34, // [34:67] is the sub-list for method input_type
8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
66, // [66:98] is the sub-list for method output_type
34, // [34:66] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6292,7 +6174,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 84,
NumMessages: 82,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -95,8 +95,6 @@ service DaemonService {
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
}
@@ -217,7 +215,6 @@ message WaitSSOLoginResponse {
message UpRequest {
optional string profileName = 1;
optional string username = 2;
optional bool autoUpdate = 3;
}
message UpResponse {}
@@ -303,8 +300,6 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool disable_firewall = 27;
}
// PeerState contains the latest state of a peer
@@ -777,11 +772,3 @@ message WaitJWTTokenResponse {
// expiration time in seconds
int64 expiresIn = 3;
}
message InstallerResultRequest {
}
message InstallerResultResponse {
bool success = 1;
string errorMsg = 2;
}

View File

@@ -71,7 +71,6 @@ type DaemonServiceClient interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
type daemonServiceClient struct {
@@ -393,15 +392,6 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec
return out, nil
}
func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) {
out := new(InstallerResultResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -459,7 +449,6 @@ type DaemonServiceServer interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -563,9 +552,6 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1158,24 +1144,6 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InstallerResultRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetInstallerResult(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetInstallerResult",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1307,10 +1275,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,
},
{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -14,4 +14,4 @@ cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"
cd "$old_pwd"

View File

@@ -192,7 +192,7 @@ func (s *Server) Start() error {
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -223,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -231,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
@@ -260,8 +260,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
doInitialAutoUpdate = false
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
@@ -729,12 +728,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
var doAutoUpdate bool
if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
doAutoUpdate = true
}
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
@@ -1386,7 +1380,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableClientRoutes := cfg.DisableClientRoutes
disableServerRoutes := cfg.DisableServerRoutes
blockLANAccess := cfg.BlockLANAccess
disableFirewall := cfg.DisableFirewall
enableSSHRoot := false
if cfg.EnableSSHRoot != nil {
@@ -1443,7 +1436,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
DisableFirewall: disableFirewall,
}, nil
}
@@ -1547,9 +1539,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err

View File

@@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}

View File

@@ -1,30 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/proto"
)
func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) {
inst := installer.New()
dir := inst.TempDir()
rh := installer.NewResultHandler(dir)
result, err := rh.Watch(ctx)
if err != nil {
log.Errorf("failed to watch update result: %v", err)
return &proto.InstallerResultResponse{
Success: false,
ErrorMsg: err.Error(),
}, nil
}
return &proto.InstallerResultResponse{
Success: result.Success,
ErrorMsg: result.Error,
}, nil
}

View File

@@ -1,184 +0,0 @@
package auth
import (
"errors"
"fmt"
"sync"
log "github.com/sirupsen/logrus"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
// DefaultUserIDClaim is the default JWT claim used to extract user IDs
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
)
var (
ErrEmptyUserID = errors.New("JWT user ID is empty")
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
)
// Authorizer handles SSH fine-grained access control authorization
type Authorizer struct {
// UserIDClaim is the JWT claim to extract the user ID from
userIDClaim string
// authorizedUsers is a list of hashed user IDs authorized to access this peer
authorizedUsers []sshuserhash.UserIDHash
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// mu protects the list of users
mu sync.RWMutex
}
// Config contains configuration for the SSH authorizer
type Config struct {
// UserIDClaim is the JWT claim to extract the user ID from (e.g., "sub", "email")
UserIDClaim string
// AuthorizedUsers is a list of hashed user IDs (FNV-1a 64-bit) authorized to access this peer
AuthorizedUsers []sshuserhash.UserIDHash
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
}
// Update updates the authorizer configuration with new values
func (a *Authorizer) Update(config *Config) {
a.mu.Lock()
defer a.mu.Unlock()
if config == nil {
// Clear authorization
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
log.Info("SSH authorization cleared")
return
}
userIDClaim := config.UserIDClaim
if userIDClaim == "" {
userIDClaim = DefaultUserIDClaim
}
a.userIDClaim = userIDClaim
// Store authorized users list
a.authorizedUsers = config.AuthorizedUsers
// Store machine users mapping
machineUsers := make(map[string][]uint32)
for osUser, indexes := range config.MachineUsers {
if len(indexes) > 0 {
machineUsers[osUser] = indexes
}
}
a.machineUsers = machineUsers
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user
// Returns nil if authorized, or an error describing why authorization failed
func (a *Authorizer) Authorize(jwtUserID, osUsername string) error {
if jwtUserID == "" {
log.Warnf("SSH auth denied: JWT user ID is empty for OS user '%s'", osUsername)
return ErrEmptyUserID
}
// Hash the JWT user ID for comparison
hashedUserID, err := sshuserhash.HashUserID(jwtUserID)
if err != nil {
log.Errorf("SSH auth denied: failed to hash user ID '%s' for OS user '%s': %v", jwtUserID, osUsername, err)
return fmt.Errorf("failed to hash user ID: %w", err)
}
a.mu.RLock()
defer a.mu.RUnlock()
// Find the index of this user in the authorized list
userIndex, found := a.findUserIndex(hashedUserID)
if !found {
log.Warnf("SSH auth denied: user '%s' (hash: %s) not in authorized list for OS user '%s'", jwtUserID, hashedUserID, osUsername)
return ErrUserNotAuthorized
}
return a.checkMachineUserMapping(jwtUserID, osUsername, userIndex)
}
// checkMachineUserMapping validates if a user's index is authorized for the specified OS user
// Checks wildcard mapping first, then specific OS user mappings
func (a *Authorizer) checkMachineUserMapping(jwtUserID, osUsername string, userIndex int) error {
// If wildcard exists and user's index is in the wildcard list, allow access to any OS user
if wildcardIndexes, hasWildcard := a.machineUsers[Wildcard]; hasWildcard {
if a.isIndexInList(uint32(userIndex), wildcardIndexes) {
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' via wildcard (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
}
// Check for specific OS username mapping
allowedIndexes, hasMachineUserMapping := a.machineUsers[osUsername]
if !hasMachineUserMapping {
// No mapping for this OS user - deny by default (fail closed)
log.Warnf("SSH auth denied: no machine user mapping for OS user '%s' (JWT user: %s)", osUsername, jwtUserID)
return ErrNoMachineUserMapping
}
// Check if user's index is in the allowed indexes for this specific OS user
if !a.isIndexInList(uint32(userIndex), allowedIndexes) {
log.Warnf("SSH auth denied: user '%s' not mapped to OS user '%s' (user index: %d)", jwtUserID, osUsername, userIndex)
return ErrUserNotMappedToOSUser
}
log.Infof("SSH auth granted: user '%s' authorized for OS user '%s' (index: %d)", jwtUserID, osUsername, userIndex)
return nil
}
// GetUserIDClaim returns the JWT claim name used to extract user IDs
func (a *Authorizer) GetUserIDClaim() string {
a.mu.RLock()
defer a.mu.RUnlock()
return a.userIDClaim
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
for i, id := range a.authorizedUsers {
if id == hashedUserID {
return i, true
}
}
return 0, false
}
// isIndexInList checks if an index exists in a list of indexes
func (a *Authorizer) isIndexInList(index uint32, indexes []uint32) bool {
for _, idx := range indexes {
if idx == index {
return true
}
}
return false
}

View File

@@ -1,612 +0,0 @@
package auth
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/sshauth"
)
func TestAuthorizer_Authorize_UserNotInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list with one user
authorizedUserHash, err := sshauth.HashUserID("authorized-user")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{authorizedUserHash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Try to authorize a different user
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Authorize_UserInList_NoMachineUserRestrictions(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{}, // Empty = deny all (fail closed)
}
authorizer.Update(config)
// All attempts should fail when no machine user mappings exist (fail closed)
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Allowed(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should access root and admin
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "admin")
assert.NoError(t, err)
// user2 (index 1) should access root and postgres
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user3 (index 2) should access postgres
err = authorizer.Authorize("user3", "postgres")
assert.NoError(t, err)
}
func TestAuthorizer_Authorize_UserInList_WithMachineUserMapping_Denied(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1}, // user1 and user2 can access root
"postgres": {1, 2}, // user2 and user3 can access postgres
"admin": {0}, // only user1 can access admin
},
}
authorizer.Update(config)
// user1 (index 0) should NOT access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 (index 1) should NOT access admin
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access root
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user3 (index 2) should NOT access admin
err = authorizer.Authorize("user3", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
}
func TestAuthorizer_Authorize_UserInList_OSUserNotInMapping(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only root is mapped
},
}
authorizer.Update(config)
// user1 should NOT access an unmapped OS user (fail closed)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Authorize_EmptyJWTUserID(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users list
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Empty user ID should fail
err = authorizer.Authorize("", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrEmptyUserID)
}
func TestAuthorizer_Authorize_MultipleUsersInList(t *testing.T) {
authorizer := NewAuthorizer()
// Set up multiple authorized users
userHashes := make([]sshauth.UserIDHash, 10)
for i := 0; i < 10; i++ {
hash, err := sshauth.HashUserID("user" + string(rune('0'+i)))
require.NoError(t, err)
userHashes[i] = hash
}
// Create machine user mapping for all users
rootIndexes := make([]uint32, 10)
for i := 0; i < 10; i++ {
rootIndexes[i] = uint32(i)
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": rootIndexes,
},
}
authorizer.Update(config)
// All users should be authorized for root
for i := 0; i < 10; i++ {
err := authorizer.Authorize("user"+string(rune('0'+i)), "root")
assert.NoError(t, err, "user%d should be authorized", i)
}
// User not in list should fail
err := authorizer.Authorize("unknown-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_ClearsConfiguration(t *testing.T) {
authorizer := NewAuthorizer()
// Set up initial configuration
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{"root": {0}},
}
authorizer.Update(config)
// user1 should be authorized
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Clear configuration
authorizer.Update(nil)
// user1 should no longer be authorized
err = authorizer.Authorize("user1", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Update_EmptyMachineUsersListEntries(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Machine users with empty index lists should be filtered out
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
"postgres": {}, // empty list - should be filtered out
"admin": nil, // nil list - should be filtered out
},
}
authorizer.Update(config)
// root should work
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// postgres should fail (no mapping)
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// admin should fail (no mapping)
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_CustomUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Set up with custom user ID claim
user1Hash, err := sshauth.HashUserID("user@example.com")
require.NoError(t, err)
config := &Config{
UserIDClaim: "email",
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"root": {0},
},
}
authorizer.Update(config)
// Verify the custom claim is set
assert.Equal(t, "email", authorizer.GetUserIDClaim())
// Authorize with email as user ID
err = authorizer.Authorize("user@example.com", "root")
assert.NoError(t, err)
}
func TestAuthorizer_DefaultUserIDClaim(t *testing.T) {
authorizer := NewAuthorizer()
// Verify default claim
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
assert.Equal(t, "sub", authorizer.GetUserIDClaim())
// Set up with empty user ID claim (should use default)
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
config := &Config{
UserIDClaim: "", // empty - should use default
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{},
}
authorizer.Update(config)
// Should fall back to default
assert.Equal(t, DefaultUserIDClaim, authorizer.GetUserIDClaim())
}
func TestAuthorizer_MachineUserMapping_LargeIndexes(t *testing.T) {
authorizer := NewAuthorizer()
// Create a large authorized users list
const numUsers = 1000
userHashes := make([]sshauth.UserIDHash, numUsers)
for i := 0; i < numUsers; i++ {
hash, err := sshauth.HashUserID("user" + string(rune(i)))
require.NoError(t, err)
userHashes[i] = hash
}
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: userHashes,
MachineUsers: map[string][]uint32{
"root": {0, 500, 999}, // first, middle, and last user
},
}
authorizer.Update(config)
// First user should have access
err := authorizer.Authorize("user"+string(rune(0)), "root")
assert.NoError(t, err)
// Middle user should have access
err = authorizer.Authorize("user"+string(rune(500)), "root")
assert.NoError(t, err)
// Last user should have access
err = authorizer.Authorize("user"+string(rune(999)), "root")
assert.NoError(t, err)
// User not in mapping should NOT have access
err = authorizer.Authorize("user"+string(rune(100)), "root")
assert.Error(t, err)
}
func TestAuthorizer_ConcurrentAuthorization(t *testing.T) {
authorizer := NewAuthorizer()
// Set up authorized users
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0, 1},
},
}
authorizer.Update(config)
// Test concurrent authorization calls (should be safe to read concurrently)
const numGoroutines = 100
errChan := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(idx int) {
user := "user1"
if idx%2 == 0 {
user = "user2"
}
err := authorizer.Authorize(user, "root")
errChan <- err
}(i)
}
// Wait for all goroutines to complete and collect errors
for i := 0; i < numGoroutines; i++ {
err := <-errChan
assert.NoError(t, err)
}
}
func TestAuthorizer_Wildcard_AllowsAllAuthorizedUsers(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
user3Hash, err := sshauth.HashUserID("user3")
require.NoError(t, err)
// Configure with wildcard - all authorized users can access any OS user
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash, user3Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1, 2}, // wildcard with all user indexes
},
}
authorizer.Update(config)
// All authorized users should be able to access any OS user
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "admin")
assert.NoError(t, err)
err = authorizer.Authorize("user1", "ubuntu")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "nginx")
assert.NoError(t, err)
err = authorizer.Authorize("user3", "docker")
assert.NoError(t, err)
}
func TestAuthorizer_Wildcard_UnauthorizedUserStillDenied(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
// Configure with wildcard
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash},
MachineUsers: map[string][]uint32{
"*": {0},
},
}
authorizer.Update(config)
// user1 should have access
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// Unauthorized user should still be denied even with wildcard
err = authorizer.Authorize("unauthorized-user", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized)
}
func TestAuthorizer_Wildcard_TakesPrecedenceOverSpecificMappings(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with both wildcard and specific mappings
// Wildcard takes precedence for users in the wildcard index list
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0, 1}, // wildcard for both users
"root": {0}, // specific mapping that would normally restrict to user1 only
},
}
authorizer.Update(config)
// Both users should be able to access root via wildcard (takes precedence over specific mapping)
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "root")
assert.NoError(t, err)
// Both users should be able to access any other OS user via wildcard
err = authorizer.Authorize("user1", "postgres")
assert.NoError(t, err)
err = authorizer.Authorize("user2", "admin")
assert.NoError(t, err)
}
func TestAuthorizer_NoWildcard_SpecificMappingsOnly(t *testing.T) {
authorizer := NewAuthorizer()
user1Hash, err := sshauth.HashUserID("user1")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure WITHOUT wildcard - only specific mappings
config := &Config{
UserIDClaim: DefaultUserIDClaim,
AuthorizedUsers: []sshauth.UserIDHash{user1Hash, user2Hash},
MachineUsers: map[string][]uint32{
"root": {0}, // only user1
"postgres": {1}, // only user2
},
}
authorizer.Update(config)
// user1 can access root
err = authorizer.Authorize("user1", "root")
assert.NoError(t, err)
// user2 can access postgres
err = authorizer.Authorize("user2", "postgres")
assert.NoError(t, err)
// user1 cannot access postgres
err = authorizer.Authorize("user1", "postgres")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// user2 cannot access root
err = authorizer.Authorize("user2", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotMappedToOSUser)
// Neither can access unmapped OS users
err = authorizer.Authorize("user1", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "admin")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
}
func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
// This test covers the scenario where wildcard exists with limited indexes.
// Only users whose indexes are in the wildcard list can access any OS user via wildcard.
// Other users can only access OS users they are explicitly mapped to.
authorizer := NewAuthorizer()
// Create two authorized user hashes (simulating the base64-encoded hashes in the config)
wasmHash, err := sshauth.HashUserID("wasm")
require.NoError(t, err)
user2Hash, err := sshauth.HashUserID("user2")
require.NoError(t, err)
// Configure with wildcard having only index 0, and specific mappings for other OS users
config := &Config{
UserIDClaim: "sub",
AuthorizedUsers: []sshauth.UserIDHash{wasmHash, user2Hash},
MachineUsers: map[string][]uint32{
"*": {0}, // wildcard with only index 0 - only wasm has wildcard access
"alice": {1}, // specific mapping for user2
"bob": {1}, // specific mapping for user2
},
}
authorizer.Update(config)
// wasm (index 0) should access any OS user via wildcard
err = authorizer.Authorize("wasm", "root")
assert.NoError(t, err, "wasm should access root via wildcard")
err = authorizer.Authorize("wasm", "alice")
assert.NoError(t, err, "wasm should access alice via wildcard")
err = authorizer.Authorize("wasm", "bob")
assert.NoError(t, err, "wasm should access bob via wildcard")
err = authorizer.Authorize("wasm", "postgres")
assert.NoError(t, err, "wasm should access postgres via wildcard")
// user2 (index 1) should only access alice and bob (explicitly mapped), NOT root or postgres
err = authorizer.Authorize("user2", "alice")
assert.NoError(t, err, "user2 should access alice via explicit mapping")
err = authorizer.Authorize("user2", "bob")
assert.NoError(t, err, "user2 should access bob via explicit mapping")
err = authorizer.Authorize("user2", "root")
assert.Error(t, err, "user2 should NOT access root (not in wildcard indexes)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
err = authorizer.Authorize("user2", "postgres")
assert.Error(t, err, "user2 should NOT access postgres (not explicitly mapped)")
assert.ErrorIs(t, err, ErrNoMachineUserMapping)
// Unauthorized user should still be denied
err = authorizer.Authorize("user3", "root")
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}

View File

@@ -27,11 +27,9 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -139,21 +137,6 @@ func TestSSHProxy_Connect(t *testing.T) {
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
@@ -167,10 +150,10 @@ func TestSSHProxy_Connect(t *testing.T) {
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
@@ -364,12 +347,12 @@ func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}

View File

@@ -23,12 +23,10 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestJWTEnforcement(t *testing.T) {
@@ -579,22 +577,6 @@ func TestJWTAuthentication(t *testing.T) {
tc.setupServer(server)
}
// Always set up authorization for test-user to ensure tests fail at JWT validation stage
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
// Get current OS username for machine user mapping
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0}, // Allow test-user (index 0) to access current OS user
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())

View File

@@ -21,7 +21,6 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
@@ -139,8 +138,6 @@ type Server struct {
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
authorizer *sshauth.Authorizer
suSupportsPty bool
loginIsUtilLinux bool
}
@@ -182,7 +179,6 @@ func New(config *Config) *Server {
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
authorizer: sshauth.NewAuthorizer(), // Initialize with empty config
}
return s
@@ -324,19 +320,6 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
// Reset JWT validator/extractor to pick up new userIDClaim
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
@@ -345,7 +328,6 @@ func (s *Server) ensureJWTValidator() error {
return nil
}
config := s.jwtConfig
authorizer := s.authorizer
s.mu.RUnlock()
if config == nil {
@@ -361,16 +343,9 @@ func (s *Server) ensureJWTValidator() error {
true,
)
// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
log.Debugf("Using custom user ID claim: %s", authorizer.GetUserIDClaim())
}
extractor := jwt.NewClaimsExtractor(extractorOptions...)
)
s.mu.Lock()
defer s.mu.Unlock()
@@ -518,41 +493,29 @@ func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]int
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
osUsername := ctx.User()
remoteAddr := ctx.RemoteAddr()
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", osUsername, remoteAddr, err)
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", osUsername, remoteAddr, err)
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", osUsername, remoteAddr, err)
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
s.mu.RLock()
authorizer := s.authorizer
s.mu.RUnlock()
if err := authorizer.Authorize(userAuth.UserId, osUsername); err != nil {
log.Warnf("SSH authorization denied for user %s (JWT user ID: %s) from %s: %v", osUsername, userAuth.UserId, remoteAddr, err)
return false
}
key := newAuthKey(osUsername, remoteAddr)
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", osUsername, userAuth.UserId, remoteAddr)
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}

View File

@@ -120,6 +120,26 @@ func (i *Info) SetFlags(
}
}
func (i *Info) CopyFlagsFrom(other *Info) {
i.SetFlags(
other.RosenpassEnabled,
other.RosenpassPermissive,
&other.ServerSSHAllowed,
other.DisableClientRoutes,
other.DisableServerRoutes,
other.DisableDNS,
other.DisableFirewall,
other.BlockLANAccess,
other.BlockInbound,
other.LazyConnectionEnabled,
&other.EnableSSHRoot,
&other.EnableSSHSFTP,
&other.EnableSSHLocalPortForwarding,
&other.EnableSSHRemotePortForwarding,
&other.DisableSSHAuth,
)
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)

View File

@@ -8,6 +8,90 @@ import (
"google.golang.org/grpc/metadata"
)
func TestInfo_CopyFlagsFrom(t *testing.T) {
origin := &Info{}
serverSSHAllowed := true
enableSSHRoot := true
enableSSHSFTP := false
enableSSHLocalPortForwarding := true
enableSSHRemotePortForwarding := false
disableSSHAuth := true
origin.SetFlags(
true, // RosenpassEnabled
false, // RosenpassPermissive
&serverSSHAllowed,
true, // DisableClientRoutes
false, // DisableServerRoutes
true, // DisableDNS
false, // DisableFirewall
true, // BlockLANAccess
false, // BlockInbound
true, // LazyConnectionEnabled
&enableSSHRoot,
&enableSSHSFTP,
&enableSSHLocalPortForwarding,
&enableSSHRemotePortForwarding,
&disableSSHAuth,
)
got := &Info{}
got.CopyFlagsFrom(origin)
if got.RosenpassEnabled != true {
t.Fatalf("RosenpassEnabled not copied: got %v", got.RosenpassEnabled)
}
if got.RosenpassPermissive != false {
t.Fatalf("RosenpassPermissive not copied: got %v", got.RosenpassPermissive)
}
if got.ServerSSHAllowed != true {
t.Fatalf("ServerSSHAllowed not copied: got %v", got.ServerSSHAllowed)
}
if got.DisableClientRoutes != true {
t.Fatalf("DisableClientRoutes not copied: got %v", got.DisableClientRoutes)
}
if got.DisableServerRoutes != false {
t.Fatalf("DisableServerRoutes not copied: got %v", got.DisableServerRoutes)
}
if got.DisableDNS != true {
t.Fatalf("DisableDNS not copied: got %v", got.DisableDNS)
}
if got.DisableFirewall != false {
t.Fatalf("DisableFirewall not copied: got %v", got.DisableFirewall)
}
if got.BlockLANAccess != true {
t.Fatalf("BlockLANAccess not copied: got %v", got.BlockLANAccess)
}
if got.BlockInbound != false {
t.Fatalf("BlockInbound not copied: got %v", got.BlockInbound)
}
if got.LazyConnectionEnabled != true {
t.Fatalf("LazyConnectionEnabled not copied: got %v", got.LazyConnectionEnabled)
}
if got.EnableSSHRoot != true {
t.Fatalf("EnableSSHRoot not copied: got %v", got.EnableSSHRoot)
}
if got.EnableSSHSFTP != false {
t.Fatalf("EnableSSHSFTP not copied: got %v", got.EnableSSHSFTP)
}
if got.EnableSSHLocalPortForwarding != true {
t.Fatalf("EnableSSHLocalPortForwarding not copied: got %v", got.EnableSSHLocalPortForwarding)
}
if got.EnableSSHRemotePortForwarding != false {
t.Fatalf("EnableSSHRemotePortForwarding not copied: got %v", got.EnableSSHRemotePortForwarding)
}
if got.DisableSSHAuth != true {
t.Fatalf("DisableSSHAuth not copied: got %v", got.DisableSSHAuth)
}
// ensure CopyFlagsFrom does not touch unrelated fields
origin.Hostname = "host-a"
got.Hostname = "host-b"
got.CopyFlagsFrom(origin)
if got.Hostname != "host-b" {
t.Fatalf("CopyFlagsFrom should not overwrite non-flag fields, got Hostname=%q", got.Hostname)
}
}
func Test_LocalWTVersion(t *testing.T) {
got := GetInfo(context.TODO())
want := "development"

View File

@@ -34,7 +34,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
protobuf "google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
@@ -44,6 +43,7 @@ import (
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
"github.com/netbirdio/netbird/client/ui/process"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -87,24 +87,22 @@ func main() {
// Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(&newServiceClientArgs{
addr: flags.daemonAddr,
logFile: logFile,
app: a,
showSettings: flags.showSettings,
showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
showQuickActions: flags.showQuickActions,
showUpdate: flags.showUpdate,
showUpdateVersion: flags.showUpdateVersion,
addr: flags.daemonAddr,
logFile: logFile,
app: a,
showSettings: flags.showSettings,
showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
showQuickActions: flags.showQuickActions,
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
a.Run()
return
}
@@ -130,17 +128,15 @@ func main() {
}
type cliFlags struct {
daemonAddr string
showSettings bool
showNetworks bool
showProfiles bool
showDebug bool
showLoginURL bool
showQuickActions bool
errorMsg string
saveLogsInFile bool
showUpdate bool
showUpdateVersion string
daemonAddr string
showSettings bool
showNetworks bool
showProfiles bool
showDebug bool
showLoginURL bool
showQuickActions bool
errorMsg string
saveLogsInFile bool
}
// parseFlags reads and returns all needed command-line flags.
@@ -160,8 +156,6 @@ func parseFlags() *cliFlags {
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.Parse()
return &flags
}
@@ -312,8 +306,6 @@ type serviceClient struct {
daemonVersion string
updateIndicationLock sync.Mutex
isUpdateIconActive bool
settingsEnabled bool
profilesEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -327,8 +319,6 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
wUpdateProgress fyne.Window
updateContextCancel context.CancelFunc
connectCancel context.CancelFunc
}
@@ -339,17 +329,15 @@ type menuHandler struct {
}
type newServiceClientArgs struct {
addr string
logFile string
app fyne.App
showSettings bool
showNetworks bool
showDebug bool
showLoginURL bool
showProfiles bool
showQuickActions bool
showUpdate bool
showUpdateVersion string
addr string
logFile string
app fyne.App
showSettings bool
showNetworks bool
showDebug bool
showLoginURL bool
showProfiles bool
showQuickActions bool
}
// newServiceClient instance constructor
@@ -367,7 +355,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
update: version.NewUpdateAndStart("nb/client-ui"),
update: version.NewUpdate("nb/client-ui"),
}
s.eventHandler = newEventHandler(s)
@@ -387,8 +375,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
}
return s
@@ -828,7 +814,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log
return nil
}
func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error {
func (s *serviceClient) menuUpClick(ctx context.Context) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -850,9 +836,7 @@ func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) e
return nil
}
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{
AutoUpdate: protobuf.Bool(wannaAutoUpdate),
}); err != nil {
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("start connection: %w", err)
}
@@ -909,7 +893,7 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool
switch {
case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled():
case status.Status == string(internal.StatusConnected):
s.connected = true
s.sendNotification = true
if s.isUpdateIconActive {
@@ -923,7 +907,6 @@ func (s *serviceClient) updateStatus() error {
s.mUp.Disable()
s.mDown.Enable()
s.mNetworks.Enable()
s.mExitNode.Enable()
go s.updateExitNodes()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -1114,26 +1097,6 @@ func (s *serviceClient) onTrayReady() {
s.updateExitNodes()
}
})
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
// todo use new Category
if windowAction, ok := event.Metadata["progress_window"]; ok {
targetVersion, ok := event.Metadata["version"]
if !ok {
targetVersion = "unknown"
}
log.Debugf("window action: %v", windowAction)
if windowAction == "show" {
if s.updateContextCancel != nil {
s.updateContextCancel()
s.updateContextCancel = nil
}
subCtx, cancel := context.WithCancel(s.ctx)
go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion)
s.updateContextCancel = cancel
}
}
})
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)
@@ -1277,22 +1240,19 @@ func (s *serviceClient) checkAndUpdateFeatures() {
return
}
s.updateIndicationLock.Lock()
defer s.updateIndicationLock.Unlock()
// Update settings menu based on current features
settingsEnabled := features == nil || !features.DisableUpdateSettings
if s.settingsEnabled != settingsEnabled {
s.settingsEnabled = settingsEnabled
s.setSettingsEnabled(settingsEnabled)
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
} else {
s.setSettingsEnabled(true)
}
// Update profile menu based on current features
if s.mProfile != nil {
profilesEnabled := features == nil || !features.DisableProfiles
if s.profilesEnabled != profilesEnabled {
s.profilesEnabled = profilesEnabled
s.mProfile.setEnabled(profilesEnabled)
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
} else {
s.mProfile.setEnabled(true)
}
}
}

View File

@@ -80,7 +80,7 @@ func (h *eventHandler) handleConnectClick() {
go func() {
defer connectCancel()
if err := h.client.menuUpClick(connectCtx, true); err != nil {
if err := h.client.menuUpClick(connectCtx); err != nil {
st, ok := status.FromError(err)
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
@@ -185,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() {
defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig()
h.runSelfCommand(h.client.ctx, "settings")
h.runSelfCommand(h.client.ctx, "settings", "true")
}()
}
@@ -193,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable()
go func() {
defer h.client.mCreateDebugBundle.Enable()
h.runSelfCommand(h.client.ctx, "debug")
h.runSelfCommand(h.client.ctx, "debug", "true")
}()
}
@@ -217,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable()
go func() {
defer h.client.mNetworks.Enable()
h.runSelfCommand(h.client.ctx, "networks")
h.runSelfCommand(h.client.ctx, "networks", "true")
}()
}
@@ -237,21 +237,17 @@ func (h *eventHandler) updateConfigWithErr() error {
return nil
}
func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) {
func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("error getting executable path: %v", err)
return
}
// Build the full command arguments
cmdArgs := []string{
fmt.Sprintf("--%s=true", command),
cmd := exec.CommandContext(ctx, proc,
fmt.Sprintf("--%s=%s", command, arg),
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
}
cmdArgs = append(cmdArgs, args...)
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
)
if out := h.client.attachOutput(cmd); out != nil {
defer func() {
@@ -261,17 +257,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args
}()
}
log.Printf("running command: %s", cmd.String())
log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr)
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode())
log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
}
return
}
log.Printf("command '%s' completed successfully", cmd.String())
log.Printf("command '%s %s' completed successfully", command, arg)
}
func (h *eventHandler) logout(ctx context.Context) error {

View File

@@ -397,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem
profilesState []Profile
downClickCallback func() error
upClickCallback func(context.Context, bool) error
upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -411,7 +411,7 @@ type newProfileMenuArgs struct {
profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem
downClickCallback func() error
upClickCallback func(context.Context, bool) error
upClickCallback func(context.Context) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -579,7 +579,7 @@ func (p *profileMenu) refresh() {
connectCtx, connectCancel := context.WithCancel(p.ctx)
p.serviceClient.connectCancel = connectCancel
if err := p.upClickCallback(connectCtx, false); err != nil {
if err := p.upClickCallback(connectCtx); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err)
}

View File

@@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() {
connCmd := connectCommand{
connectClient: func() error {
return s.menuUpClick(s.ctx, false)
return s.menuUpClick(s.ctx)
},
}

View File

@@ -1,140 +0,0 @@
//go:build !(linux && 386)
package main
import (
"context"
"errors"
"fmt"
"strings"
"time"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) {
log.Infof("show installer progress window: %s", version)
s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
statusLabel := widget.NewLabel("Updating...")
infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version))
content := container.NewVBox(infoLabel, statusLabel)
s.wUpdateProgress.SetContent(content)
s.wUpdateProgress.CenterOnScreen()
s.wUpdateProgress.SetFixedSize(true)
s.wUpdateProgress.SetCloseIntercept(func() {
// this is empty to lock window until result known
})
s.wUpdateProgress.RequestFocus()
s.wUpdateProgress.Show()
updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
// Initialize dot updater
updateText := dotUpdater()
// Channel to receive the result from RPC call
resultErrCh := make(chan error, 1)
resultOkCh := make(chan struct{}, 1)
// Start RPC call in background
go func() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Infof("backend not reachable, upgrade in progress: %v", err)
close(resultOkCh)
return
}
resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{})
if err != nil {
log.Infof("backend stopped responding, upgrade in progress: %v", err)
close(resultOkCh)
return
}
if !resp.Success {
resultErrCh <- mapInstallError(resp.ErrorMsg)
return
}
// Success
close(resultOkCh)
}()
// Update UI with dots and wait for result
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
defer cancel()
// allow closing update window after 10 sec
timerResetCloseInterceptor := time.NewTimer(10 * time.Second)
defer timerResetCloseInterceptor.Stop()
for {
select {
case <-updateWindowCtx.Done():
s.showInstallerResult(statusLabel, updateWindowCtx.Err())
return
case err := <-resultErrCh:
s.showInstallerResult(statusLabel, err)
return
case <-resultOkCh:
log.Info("backend exited, upgrade in progress, closing all UI")
killParentUIProcess()
s.app.Quit()
return
case <-ticker.C:
statusLabel.SetText(updateText())
case <-timerResetCloseInterceptor.C:
s.wUpdateProgress.SetCloseIntercept(nil)
}
}
}()
}
func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) {
s.wUpdateProgress.SetCloseIntercept(nil)
switch {
case errors.Is(err, context.DeadlineExceeded):
log.Warn("update watcher timed out")
statusLabel.SetText("Update timed out. Please try again.")
case errors.Is(err, context.Canceled):
log.Info("update watcher canceled")
statusLabel.SetText("Update canceled.")
case err != nil:
log.Errorf("update failed: %v", err)
statusLabel.SetText("Update failed: " + err.Error())
default:
s.wUpdateProgress.Close()
}
}
// dotUpdater returns a closure that cycles through dots for a loading animation.
func dotUpdater() func() string {
dotCount := 0
return func() string {
dotCount = (dotCount + 1) % 4
return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount))
}
}
func mapInstallError(msg string) error {
msg = strings.ToLower(strings.TrimSpace(msg))
switch {
case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"):
return context.DeadlineExceeded
case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"):
return context.Canceled
case msg == "":
return errors.New("unknown update error")
default:
return errors.New(msg)
}
}

View File

@@ -1,7 +0,0 @@
//go:build !windows && !(linux && 386)
package main
func killParentUIProcess() {
// No-op on non-Windows platforms
}

View File

@@ -1,44 +0,0 @@
//go:build windows
package main
import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
nbprocess "github.com/netbirdio/netbird/client/ui/process"
)
// killParentUIProcess finds and kills the parent systray UI process on Windows.
// This is a workaround in case the MSI installer fails to properly terminate the UI process.
// The installer should handle this via util:CloseApplication with TerminateProcess, but this
// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds.
func killParentUIProcess() {
pid, running, err := nbprocess.IsAnotherProcessRunning()
if err != nil {
log.Warnf("failed to check for parent UI process: %v", err)
return
}
if !running {
log.Debug("no parent UI process found to kill")
return
}
log.Infof("killing parent UI process (PID: %d)", pid)
// Open the process with terminate rights
handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid))
if err != nil {
log.Warnf("failed to open parent process %d: %v", pid, err)
return
}
defer func() {
_ = windows.CloseHandle(handle)
}()
// Terminate the process with exit code 0
if err := windows.TerminateProcess(handle, 0); err != nil {
log.Warnf("failed to terminate parent process %d: %v", pid, err)
}
}

View File

@@ -60,7 +60,14 @@ func (hook ContextHook) Fire(entry *logrus.Entry) error {
entry.Data["context"] = source
addFields(entry)
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
return nil
}
@@ -92,7 +99,7 @@ func (hook ContextHook) parseSrc(filePath string) string {
return fmt.Sprintf("%s/%s", pkg, file)
}
func addFields(entry *logrus.Entry) {
func addHTTPFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
@@ -102,6 +109,30 @@ func addFields(entry *logrus.Entry) {
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}

View File

@@ -53,8 +53,7 @@ services:
command: [
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
"--log-file", "console",
"--port", "80"
"--log-file", "console"
]
# Relay

View File

@@ -178,7 +178,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
@@ -225,7 +224,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -321,7 +320,6 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil {
@@ -340,7 +338,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -396,23 +394,26 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
if isRequiresApproval {
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if clientSerial > 0 && clientSerial == network.CurrentSerial() {
log.WithContext(ctx).Debugf("client serial %d matches current serial, skipping network map calculation", clientSerial)
return peer, nil, nil, 0, nil
}
var account *types.Account
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
@@ -447,7 +448,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -813,7 +814,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -24,7 +24,7 @@ type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer, clientSerial uint64) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -113,9 +113,9 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer, clientSerial uint64) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p, clientSerial)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
@@ -125,9 +125,9 @@ func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequires
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p, clientSerial any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p, clientSerial)
}
// OnPeerConnected mocks base method.

View File

@@ -158,7 +158,5 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
}
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}

View File

@@ -10,9 +10,9 @@ import (
"slices"
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -180,7 +180,7 @@ func unaryInterceptor(
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := xid.New().String()
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, hook.GRPCSource)
//nolint
@@ -194,7 +194,7 @@ func streamInterceptor(
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := xid.New().String()
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), hook.ExecutionContextKey, hook.GRPCSource)

View File

@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
s.update = version.NewUpdateAndStart("nb/management")
s.update = version.NewUpdate("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())

Some files were not shown because too many files have changed in this diff Show More