[client, management] auto-update (#4732)

This commit is contained in:
Zoltan Papp
2025-12-19 19:57:39 +01:00
committed by GitHub
parent 537151e0f3
commit 011cc81678
87 changed files with 9720 additions and 716 deletions

View File

@@ -122,7 +122,7 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}
@@ -149,7 +149,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
}

View File

@@ -85,6 +85,9 @@ var (
// Execute executes the root command.
func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
}
return rootCmd.Execute()
}

View File

@@ -0,0 +1,176 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
bundlePubKeysRootPrivKeyFile string
bundlePubKeysPubKeyFiles []string
bundlePubKeysFile string
createArtifactKeyRootPrivKeyFile string
createArtifactKeyPrivKeyFile string
createArtifactKeyPubKeyFile string
createArtifactKeyExpiration time.Duration
)
var createArtifactKeyCmd = &cobra.Command{
Use: "create-artifact-key",
Short: "Create a new artifact signing key",
Long: `Generate a new artifact signing key pair signed by the root private key.
The artifact key will be used to sign software artifacts/updates.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if createArtifactKeyExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
return fmt.Errorf("failed to create artifact key: %w", err)
}
return nil
},
}
var bundlePubKeysCmd = &cobra.Command{
Use: "bundle-pub-keys",
Short: "Bundle multiple artifact public keys into a signed package",
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
RunE: func(cmd *cobra.Command, args []string) error {
if len(bundlePubKeysPubKeyFiles) == 0 {
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
}
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
return fmt.Errorf("failed to bundle public keys: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createArtifactKeyCmd)
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(fmt.Errorf("mark expiration as required: %w", err))
}
rootCmd.AddCommand(bundlePubKeysCmd)
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
}
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
}
}
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
cmd.Println("Creating new artifact signing key...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
if err != nil {
return fmt.Errorf("generate artifact key: %w", err)
}
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
}
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
}
signatureFile := artifactPubKeyFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Artifact key created successfully.\n")
cmd.Printf("%s\n", artifactKey.String())
return nil
}
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
cmd.Println("📦 Bundling public keys into signed package...")
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
if err != nil {
return fmt.Errorf("read root private key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
for _, pubFile := range artifactPubKeyFiles {
pubPem, err := os.ReadFile(pubFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
pk, err := reposign.ParseArtifactPubKey(pubPem)
if err != nil {
return fmt.Errorf("failed to parse artifact key: %w", err)
}
publicKeys = append(publicKeys, pk)
}
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
if err != nil {
return fmt.Errorf("bundle artifact keys: %w", err)
}
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
}
signatureFile := bundlePubKeysFile + ".sig"
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
}
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
return nil
}

View File

@@ -0,0 +1,276 @@
package main
import (
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
)
var (
signArtifactPrivKeyFile string
signArtifactArtifactFile string
verifyArtifactPubKeyFile string
verifyArtifactFile string
verifyArtifactSignatureFile string
verifyArtifactKeyPubKeyFile string
verifyArtifactKeyRootPubKeyFile string
verifyArtifactKeySignatureFile string
verifyArtifactKeyRevocationFile string
)
var signArtifactCmd = &cobra.Command{
Use: "sign-artifact",
Short: "Sign an artifact using an artifact private key",
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
return fmt.Errorf("failed to sign artifact: %w", err)
}
return nil
},
}
var verifyArtifactCmd = &cobra.Command{
Use: "verify-artifact",
Short: "Verify an artifact signature using an artifact public key",
Long: `Verify a software artifact signature using the artifact's public key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
return fmt.Errorf("failed to verify artifact: %w", err)
}
return nil
},
}
var verifyArtifactKeyCmd = &cobra.Command{
Use: "verify-artifact-key",
Short: "Verify an artifact public key was signed by a root key",
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
This validates the chain of trust from the root key to the artifact key.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
return fmt.Errorf("failed to verify artifact key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(signArtifactCmd)
rootCmd.AddCommand(verifyArtifactCmd)
rootCmd.AddCommand(verifyArtifactKeyCmd)
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
// artifact-file is required, but artifact-key-file can come from env var
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
panic(fmt.Errorf("mark artifact-file as required: %w", err))
}
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
panic(fmt.Errorf("mark root-key-file as required: %w", err))
}
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
panic(fmt.Errorf("mark signature-file as required: %w", err))
}
}
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
cmd.Println("🖋️ Signing artifact...")
// Load private key from env var or file
var privKeyPEM []byte
var err error
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
// Use key from environment variable
privKeyPEM = []byte(envKey)
} else if privKeyFile != "" {
// Fall back to file
privKeyPEM, err = os.ReadFile(privKeyFile)
if err != nil {
return fmt.Errorf("read private key file: %w", err)
}
} else {
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
}
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact private key: %w", err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
signature, err := reposign.SignData(privateKey, artifactData)
if err != nil {
return fmt.Errorf("sign artifact: %w", err)
}
sigFile := artifactFile + ".sig"
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
}
cmd.Printf("✅ Artifact signed successfully.\n")
cmd.Printf("Signature file: %s\n", sigFile)
return nil
}
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
cmd.Println("🔍 Verifying artifact...")
// Read artifact public key
pubKeyPEM, err := os.ReadFile(pubKeyFile)
if err != nil {
return fmt.Errorf("read public key file: %w", err)
}
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse artifact public key: %w", err)
}
// Read artifact data
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
return fmt.Errorf("read artifact file: %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate artifact
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
return fmt.Errorf("artifact verification failed: %w", err)
}
cmd.Println("✅ Artifact signature is valid")
cmd.Printf("Artifact: %s\n", artifactFile)
cmd.Printf("Signed by key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
return nil
}
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
cmd.Println("🔍 Verifying artifact key...")
// Read artifact key data
artifactKeyData, err := os.ReadFile(artifactKeyFile)
if err != nil {
return fmt.Errorf("read artifact key file: %w", err)
}
// Read root public key(s)
rootKeyData, err := os.ReadFile(rootKeyFile)
if err != nil {
return fmt.Errorf("read root key file: %w", err)
}
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
if err != nil {
return fmt.Errorf("failed to parse root public key(s): %w", err)
}
// Read signature
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("read signature file: %w", err)
}
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Read optional revocation list
var revocationList *reposign.RevocationList
if revocationFile != "" {
revData, err := os.ReadFile(revocationFile)
if err != nil {
return fmt.Errorf("read revocation file: %w", err)
}
revocationList, err = reposign.ParseRevocationList(revData)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
}
// Validate artifact key(s)
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
if err != nil {
return fmt.Errorf("artifact key verification failed: %w", err)
}
cmd.Println("✅ Artifact key(s) verified successfully")
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
for i, key := range validKeys {
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
if !key.Metadata.ExpiresAt.IsZero() {
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
} else {
cmd.Printf(" Expires: Never\n")
}
}
return nil
}
// parseRootPublicKeys parses a root public key from PEM data
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
key, err := reposign.ParseRootPublicKey(data)
if err != nil {
return nil, err
}
return []reposign.PublicKey{key}, nil
}

21
client/cmd/signer/main.go Normal file
View File

@@ -0,0 +1,21 @@
package main
import (
"os"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "signer",
Short: "A CLI tool for managing cryptographic keys and artifacts",
Long: `signer is a command-line tool that helps you manage
root keys, artifact keys, and revocation lists securely.`,
}
func main() {
if err := rootCmd.Execute(); err != nil {
rootCmd.Println(err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,220 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
const (
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
)
var (
keyID string
revocationListFile string
privateRootKeyFile string
publicRootKeyFile string
signatureFile string
expirationDuration time.Duration
)
var createRevocationListCmd = &cobra.Command{
Use: "create-revocation-list",
Short: "Create a new revocation list signed by the private root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
},
}
var extendRevocationListCmd = &cobra.Command{
Use: "extend-revocation-list",
Short: "Extend an existing revocation list with a given key ID",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
},
}
var verifyRevocationListCmd = &cobra.Command{
Use: "verify-revocation-list",
Short: "Verify a revocation list signature using the public root key",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
},
}
func init() {
rootCmd.AddCommand(createRevocationListCmd)
rootCmd.AddCommand(extendRevocationListCmd)
rootCmd.AddCommand(verifyRevocationListCmd)
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
panic(err)
}
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
panic(err)
}
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
panic(err)
}
}
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
if err != nil {
return fmt.Errorf("failed to create revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list created successfully")
return nil
}
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read private root key file: %w", err)
}
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse private root key: %w", err)
}
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
rl, err := reposign.ParseRevocationList(rlBytes)
if err != nil {
return fmt.Errorf("failed to parse revocation list: %w", err)
}
kid, err := reposign.ParseKeyID(keyID)
if err != nil {
return fmt.Errorf("invalid key ID: %w", err)
}
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
if err != nil {
return fmt.Errorf("failed to extend revocation list: %w", err)
}
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
return fmt.Errorf("failed to write output files: %w", err)
}
cmd.Println("✅ Revocation list extended successfully")
return nil
}
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
// Read revocation list file
rlBytes, err := os.ReadFile(revocationListFile)
if err != nil {
return fmt.Errorf("failed to read revocation list file: %w", err)
}
// Read signature file
sigBytes, err := os.ReadFile(signatureFile)
if err != nil {
return fmt.Errorf("failed to read signature file: %w", err)
}
// Read public root key file
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
if err != nil {
return fmt.Errorf("failed to read public root key file: %w", err)
}
// Parse public root key
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public root key: %w", err)
}
// Parse signature
signature, err := reposign.ParseSignature(sigBytes)
if err != nil {
return fmt.Errorf("failed to parse signature: %w", err)
}
// Validate revocation list
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
if err != nil {
return fmt.Errorf("failed to validate revocation list: %w", err)
}
// Display results
cmd.Println("✅ Revocation list signature is valid")
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
if len(rl.Revoked) > 0 {
cmd.Println("\nRevoked Keys:")
for keyID, revokedTime := range rl.Revoked {
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
}
}
return nil
}
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
return fmt.Errorf("failed to write revocation list file: %w", err)
}
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
return fmt.Errorf("failed to write signature file: %w", err)
}
return nil
}

View File

@@ -0,0 +1,74 @@
package main
import (
"fmt"
"os"
"time"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
var (
privKeyFile string
pubKeyFile string
rootExpiration time.Duration
)
var createRootKeyCmd = &cobra.Command{
Use: "create-root-key",
Short: "Create a new root key pair",
Long: `Create a new root key pair and specify an expiration time for it.`,
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Validate expiration
if rootExpiration <= 0 {
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
}
// Run main logic
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
return fmt.Errorf("failed to generate root key: %w", err)
}
return nil
},
}
func init() {
rootCmd.AddCommand(createRootKeyCmd)
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
panic(err)
}
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
panic(err)
}
}
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
if err != nil {
return fmt.Errorf("generate root key: %w", err)
}
// Write private key
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
}
// Write public key
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
}
cmd.Printf("%s\n\n", rk.String())
cmd.Printf("✅ Root key pair generated successfully.\n")
return nil
}

View File

@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
connectClient := internal.NewConnectClient(ctx, config, r, false)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)

13
client/cmd/update.go Normal file
View File

@@ -0,0 +1,13 @@
//go:build !windows && !darwin
package cmd
import (
"github.com/spf13/cobra"
)
var updateCmd *cobra.Command
func isUpdateBinary() bool {
return false
}

View File

@@ -0,0 +1,75 @@
//go:build windows || darwin
package cmd
import (
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/util"
)
var (
updateCmd = &cobra.Command{
Use: "update",
Short: "Update the NetBird client application",
RunE: updateFunc,
}
tempDirFlag string
installerFile string
serviceDirFlag string
dryRunFlag bool
)
func init() {
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
}
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
func isUpdateBinary() bool {
// Remove extension for cross-platform compatibility
execPath, err := os.Executable()
if err != nil {
return false
}
baseName := filepath.Base(execPath)
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
return name == installer.UpdaterBinaryNameWithoutExtension()
}
func updateFunc(cmd *cobra.Command, args []string) error {
if err := setupLogToFile(tempDirFlag); err != nil {
return err
}
log.Infof("updater started: %s", serviceDirFlag)
updater := installer.NewWithDir(tempDirFlag)
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
log.Errorf("failed to update application: %v", err)
return err
}
return nil
}
func setupLogToFile(dir string) error {
logFile := filepath.Join(dir, installer.LogFile)
if _, err := os.Stat(logFile); err == nil {
if err := os.Remove(logFile); err != nil {
log.Errorf("failed to remove existing log file: %v\n", err)
}
}
return util.InitLog(logLevel, util.LogConsole, logFile)
}

View File

@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
client := internal.NewConnectClient(ctx, c.config, recorder, false)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available

View File

@@ -24,10 +24,14 @@ import (
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -39,11 +43,13 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
engine *Engine
engineMutex sync.Mutex
persistSyncResponse bool
}
@@ -52,13 +58,15 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
}
}
@@ -162,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if !fileExists(mobileDependency.StateFilePath) {
err := createFile(mobileDependency.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDependency.StateFilePath
}
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
}
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -273,7 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -283,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -362,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
log.Errorf("failed to add updater logs: %v", err)
}
return nil
}
@@ -650,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
return nil
}
func (g *BundleGenerator) addUpdateLogs() error {
inst := installer.New()
logFiles := inst.LogFiles()
if len(logFiles) == 0 {
return nil
}
log.Infof("adding updater logs")
for _, logFile := range logFiles {
data, err := os.ReadFile(logFile)
if err != nil {
log.Warnf("failed to read update log file %s: %v", logFile, err)
continue
}
baseName := filepath.Base(logFile)
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
}
}
return nil
}
func (g *BundleGenerator) addCorruptedStateFiles() error {
sm := profilemanager.NewServiceManager("")
pattern := sm.GetStatePath()

View File

@@ -42,14 +42,13 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -73,6 +72,7 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -201,6 +201,9 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -221,17 +224,7 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -247,28 +240,12 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
}
@@ -308,6 +285,10 @@ func (e *Engine) Stop() error {
e.srWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.Stop()
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
@@ -541,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -749,6 +737,41 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -758,6 +781,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())

View File

@@ -253,6 +253,7 @@ func TestEngine_SSH(t *testing.T) {
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -414,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -647,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -812,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1014,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1540,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
@@ -165,19 +166,26 @@ func getConfigDir() (string, error) {
if ConfigDirOverride != "" {
return ConfigDirOverride, nil
}
configDir, err := os.UserConfigDir()
base, err := baseConfigDir()
if err != nil {
return "", err
}
configDir = filepath.Join(configDir, "netbird")
if _, err := os.Stat(configDir); os.IsNotExist(err) {
if err := os.MkdirAll(configDir, 0755); err != nil {
return "", err
configDir := filepath.Join(base, "netbird")
if err := os.MkdirAll(configDir, 0o755); err != nil {
return "", err
}
return configDir, nil
}
func baseConfigDir() (string, error) {
if runtime.GOOS == "darwin" {
if u, err := user.Current(); err == nil && u.HomeDir != "" {
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
}
}
return configDir, nil
return os.UserConfigDir()
}
func getConfigDirForUser(username string) (string, error) {

View File

@@ -0,0 +1,35 @@
// Package updatemanager provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
// # Overview
//
// The update manager operates as a background service that continuously monitors for
// available updates and automatically initiates the update process when conditions are met.
// It integrates with the installer package to perform the actual installation.
//
// # Update Flow
//
// The complete update process follows these steps:
//
// 1. Manager receives update directive via SetVersion() or detects new version
// 2. Manager validates update should proceed (version comparison, rate limiting)
// 3. Manager publishes "updating" event to status recorder
// 4. Manager persists UpdateState to track update attempt
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
// 6. Manager triggers installation via installer.RunInstallation()
// 7. Installer package handles the actual installation process
// 8. On next startup, CheckUpdateSuccess() verifies update completion
// 9. Manager publishes success/failure event to status recorder
// 10. Manager cleans up UpdateState
//
// # State Management
//
// Update state is persisted across restarts to track update attempts:
//
// - PreUpdateVersion: Version before update attempt
// - TargetVersion: Version attempting to update to
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updatemanager

View File

@@ -0,0 +1,138 @@
package downloader
import (
"context"
"fmt"
"io"
"net/http"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
const (
userAgent = "NetBird agent installer/%s"
DefaultRetryDelay = 3 * time.Second
)
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
log.Debugf("starting download from %s", url)
out, err := os.Create(dstFile)
if err != nil {
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
}
defer func() {
if cerr := out.Close(); cerr != nil {
log.Warnf("error closing file %q: %v", dstFile, cerr)
}
}()
// First attempt
err = downloadToFileOnce(ctx, url, out)
if err == nil {
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
// If retryDelay is 0, don't retry
if retryDelay == 0 {
return err
}
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
// Sleep before retry
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
}
// Truncate file before retry
if err := out.Truncate(0); err != nil {
return fmt.Errorf("failed to truncate file on retry: %w", err)
}
if _, err := out.Seek(0, 0); err != nil {
return fmt.Errorf("failed to seek to beginning of file: %w", err)
}
// Second attempt
if err := downloadToFileOnce(ctx, url, out); err != nil {
return fmt.Errorf("download failed after retry: %w", err)
}
log.Infof("successfully downloaded file to %s", dstFile)
return nil
}
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return data, nil
}
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %w", err)
}
// Add User-Agent header
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to perform HTTP request: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
log.Warnf("error closing response body: %v", cerr)
}
}()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
}
if _, err := io.Copy(out, resp.Body); err != nil {
return fmt.Errorf("failed to write response body to file: %w", err)
}
return nil
}
func sleepWithContext(ctx context.Context, duration time.Duration) error {
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}

View File

@@ -0,0 +1,199 @@
package downloader
import (
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
const (
retryDelay = 100 * time.Millisecond
)
func TestDownloadToFile_Success(t *testing.T) {
// Create a test server that responds successfully
content := "test file content"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
}
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
content := "test file content after retry"
var attemptCount atomic.Int32
// Create a test server that fails on first attempt, succeeds on second
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempt := attemptCount.Add(1)
if attempt == 1 {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(content))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should succeed after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
t.Fatalf("expected no error after retry, got: %v", err)
}
// Verify the file content
data, err := os.ReadFile(dstFile)
if err != nil {
t.Fatalf("failed to read downloaded file: %v", err)
}
if string(data) != content {
t.Errorf("expected content %q, got %q", content, string(data))
}
// Verify it took 2 attempts
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file (should fail after retry)
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
t.Fatal("expected error after retry, got nil")
}
// Verify it tried 2 times
if attemptCount.Load() != 2 {
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Create a context that will be cancelled during retry delay
ctx, cancel := context.WithCancel(context.Background())
// Cancel after a short delay (during the retry sleep)
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
// Download the file (should fail due to context cancellation during retry)
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
if err == nil {
t.Fatal("expected error due to context cancellation, got nil")
}
// Should have only made 1 attempt (cancelled during retry delay)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}
func TestDownloadToFile_InvalidURL(t *testing.T) {
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
if err == nil {
t.Fatal("expected error for invalid URL, got nil")
}
}
func TestDownloadToFile_InvalidDestination(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("test"))
}))
defer server.Close()
// Use an invalid destination path
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
if err == nil {
t.Fatal("expected error for invalid destination, got nil")
}
}
func TestDownloadToFile_NoRetry(t *testing.T) {
var attemptCount atomic.Int32
// Create a test server that always fails
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("error"))
}))
defer server.Close()
// Create a temporary file for download
tempDir := t.TempDir()
dstFile := filepath.Join(tempDir, "downloaded.txt")
// Download the file with retryDelay = 0 (should not retry)
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
t.Fatal("expected error, got nil")
}
// Verify it only made 1 attempt (no retry)
if attemptCount.Load() != 1 {
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
}
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package installer
func UpdaterBinaryNameWithoutExtension() string {
return updaterBinary
}

View File

@@ -0,0 +1,11 @@
package installer
import (
"path/filepath"
"strings"
)
func UpdaterBinaryNameWithoutExtension() string {
ext := filepath.Ext(updaterBinary)
return strings.TrimSuffix(updaterBinary, ext)
}

View File

@@ -0,0 +1,111 @@
// Package installer provides functionality for managing NetBird application
// updates and installations across Windows, macOS. It handles
// the complete update lifecycle including artifact download, cryptographic verification,
// installation execution, process management, and result reporting.
//
// # Architecture
//
// The installer package uses a two-process architecture to enable self-updates:
//
// 1. Service Process: The main NetBird daemon process that initiates updates
// 2. Updater Process: A detached child process that performs the actual installation
//
// This separation is critical because:
// - The service binary cannot update itself while running
// - The installer (EXE/MSI/PKG) will terminate the service during installation
// - The updater process survives service termination and restarts it after installation
// - Results can be communicated back to the service after it restarts
//
// # Update Flow
//
// Service Process (RunInstallation):
//
// 1. Validates target version format (semver)
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
// 3. Downloads installer file from GitHub releases (if applicable)
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
// launching updater)
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
// 6. Launches updater process with detached mode:
// - --temp-dir: Temporary directory path
// - --service-dir: Service installation directory
// - --installer-file: Path to downloaded installer (if applicable)
// - --dry-run: Optional flag to test without actually installing
// 7. Service process continues running (will be terminated by installer later)
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
//
// Updater Process (Setup):
//
// 1. Receives parameters from service via command-line arguments
// 2. Runs installer with appropriate silent/quiet flags:
// - Windows EXE: installer.exe /S
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
// - macOS PKG: installer -pkg installer.pkg -target /
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
// 3. Installer terminates daemon and UI processes
// 4. Installer replaces binaries with new version
// 5. Updater waits for installer to complete
// 6. Updater restarts daemon:
// - Windows: netbird.exe service start
// - macOS/Linux: netbird service start
// 7. Updater restarts UI:
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
// - Linux: Not implemented (UI typically auto-starts)
// 8. Updater writes result.json with success/error status
// 9. Updater process exits
//
// # Result Communication
//
// The ResultHandler (result.go) manages communication between updater and service:
//
// Result Structure:
//
// type Result struct {
// Success bool // true if installation succeeded
// Error string // error message if Success is false
// ExecutedAt time.Time // when installation completed
// }
//
// Result files are automatically cleaned up after being read.
//
// # File Locations
//
// Temporary Directory (platform-specific):
//
// Windows:
// - Path: %ProgramData%\Netbird\tmp-install
// - Example: C:\ProgramData\Netbird\tmp-install
//
// macOS:
// - Path: /var/lib/netbird/tmp-install
// - Requires root permissions
//
// Files created during installation:
//
// tmp-install/
// installer.log
// updater[.exe] # Copy of service binary
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
// result.json # Installation result
// msi.log # MSI verbose log (Windows MSI only)
//
// # API Reference
//
// # Cleanup
//
// CleanUpInstallerFiles() removes temporary files after successful installation:
// - Downloaded installer files (*.exe, *.msi, *.pkg)
// - Updater binary copy
// - Does NOT remove result.json (cleaned by ResultHandler after read)
// - Does NOT remove msi.log (kept for debugging)
//
// # Dry-Run Mode
//
// Dry-run mode allows testing the update process without actually installing:
//
// Enable via environment variable:
//
// export NB_AUTO_UPDATE_DRY_RUN=true
// netbird service install-update 0.29.0
package installer

View File

@@ -0,0 +1,50 @@
//go:build !windows && !darwin
package installer
import (
"context"
"fmt"
)
const (
updaterBinary = "updater"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
func (u *Installer) TempDir() string {
return ""
}
func (c *Installer) LogFiles() []string {
return []string{}
}
func (u *Installer) CleanUpInstallerFiles() error {
return nil
}
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
return fmt.Errorf("unsupported platform")
}
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
return fmt.Errorf("unsupported platform")
}

View File

@@ -0,0 +1,293 @@
//go:build windows || darwin
package installer
import (
"context"
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"github.com/hashicorp/go-multierror"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
)
type Installer struct {
tempDir string
}
// New used by the service
func New() *Installer {
return &Installer{
tempDir: defaultTempDir,
}
}
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
func NewWithDir(tempDir string) *Installer {
return &Installer{
tempDir: tempDir,
}
}
// RunInstallation starts the updater process to run the installation
// This will run by the original service process
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
resultHandler := NewResultHandler(u.tempDir)
defer func() {
if err != nil {
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
log.Errorf("failed to write error result: %v", writeErr)
}
}
}()
if err := validateTargetVersion(targetVersion); err != nil {
return err
}
if err := u.mkTempDir(); err != nil {
return err
}
var installerFile string
// Download files only when not using any third-party store
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
log.Infof("download installer")
var err error
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
if err != nil {
log.Errorf("failed to download installer: %v", err)
return err
}
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
if err != nil {
log.Errorf("failed to create artifact verify: %v", err)
return err
}
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
log.Errorf("artifact verification error: %v", err)
return err
}
}
log.Infof("running installer")
updaterPath, err := u.copyUpdater()
if err != nil {
return err
}
// the directory where the service has been installed
workspace, err := getServiceDir()
if err != nil {
return err
}
args := []string{
"--temp-dir", u.tempDir,
"--service-dir", workspace,
}
if isDryRunEnabled() {
args = append(args, "--dry-run=true")
}
if installerFile != "" {
args = append(args, "--installer-file", installerFile)
}
updateCmd := exec.Command(updaterPath, args...)
log.Infof("starting updater process: %s", updateCmd.String())
// Configure the updater to run in a separate session/process group
// so it survives the parent daemon being stopped
setUpdaterProcAttr(updateCmd)
// Start the updater process asynchronously
if err := updateCmd.Start(); err != nil {
return err
}
pid := updateCmd.Process.Pid
log.Infof("updater started with PID %d", pid)
// Release the process so the OS can fully detach it
if err := updateCmd.Process.Release(); err != nil {
log.Warnf("failed to release updater process: %v", err)
}
return nil
}
// CleanUpInstallerFiles
// - the installer file (pkg, exe, msi)
// - the selfcopy updater.exe
func (u *Installer) CleanUpInstallerFiles() error {
// Check if tempDir exists
info, err := os.Stat(u.tempDir)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
if !info.IsDir() {
return nil
}
var merr *multierror.Error
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
}
entries, err := os.ReadDir(u.tempDir)
if err != nil {
return err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
for _, ext := range binaryExtensions {
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
}
break
}
}
}
return merr.ErrorOrNil()
}
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
fileURL := urlWithVersionArch(installerType, targetVersion)
// Clean up temp directory on error
var success bool
defer func() {
if !success {
if err := os.RemoveAll(u.tempDir); err != nil {
log.Errorf("error cleaning up temporary directory: %v", err)
}
}
}()
fileName := path.Base(fileURL)
if fileName == "." || fileName == "/" || fileName == "" {
return "", fmt.Errorf("invalid file URL: %s", fileURL)
}
outputFilePath := filepath.Join(u.tempDir, fileName)
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
return "", err
}
success = true
return outputFilePath, nil
}
func (u *Installer) TempDir() string {
return u.tempDir
}
func (u *Installer) mkTempDir() error {
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
log.Debugf("failed to create tempdir: %s", u.tempDir)
return err
}
return nil
}
func (u *Installer) copyUpdater() (string, error) {
src, err := getServiceBinary()
if err != nil {
return "", fmt.Errorf("failed to get updater binary: %w", err)
}
dst := filepath.Join(u.tempDir, updaterBinary)
if err := copyFile(src, dst); err != nil {
return "", fmt.Errorf("failed to copy updater binary: %w", err)
}
if err := os.Chmod(dst, 0o755); err != nil {
return "", fmt.Errorf("failed to set permissions: %w", err)
}
return dst, nil
}
func validateTargetVersion(targetVersion string) error {
if targetVersion == "" {
return fmt.Errorf("target version cannot be empty")
}
_, err := goversion.NewVersion(targetVersion)
if err != nil {
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
}
return nil
}
func copyFile(src, dst string) error {
log.Infof("copying %s to %s", src, dst)
in, err := os.Open(src)
if err != nil {
return fmt.Errorf("open source: %w", err)
}
defer func() {
if err := in.Close(); err != nil {
log.Warnf("failed to close source file: %v", err)
}
}()
out, err := os.Create(dst)
if err != nil {
return fmt.Errorf("create destination: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Warnf("failed to close destination file: %v", err)
}
}()
if _, err := io.Copy(out, in); err != nil {
return fmt.Errorf("copy: %w", err)
}
return nil
}
func getServiceDir() (string, error) {
exePath, err := os.Executable()
if err != nil {
return "", err
}
return filepath.Dir(exePath), nil
}
func getServiceBinary() (string, error) {
return os.Executable()
}
func isDryRunEnabled() bool {
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
}

View File

@@ -0,0 +1,11 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -0,0 +1,12 @@
package installer
import (
"path/filepath"
)
func (u *Installer) LogFiles() []string {
return []string{
filepath.Join(u.tempDir, msiLogFile),
filepath.Join(u.tempDir, LogFile),
}
}

View File

@@ -0,0 +1,238 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const (
daemonName = "netbird"
updaterBinary = "updater"
uiBinary = "/Applications/NetBird.app"
defaultTempDir = "/var/lib/netbird/tmp-install"
pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
var (
binaryExtensions = []string{"pkg"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
// skip service restart if dry-run mode is enabled
if dryRun {
return
}
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(); err != nil {
log.Errorf("failed to start UI: %v", err)
}
}()
if dryRun {
time.Sleep(7 * time.Second)
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
switch TypeOfInstaller(ctx) {
case TypePKG:
resultErr = u.installPkgFile(ctx, installerFile)
case TypeHomebrew:
resultErr = u.updateHomeBrew(ctx)
}
return resultErr
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser() error {
log.Infof("starting netbird-ui: %s", uiBinary)
// Get the current console user
cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("failed to get console user: %w", err)
}
username := strings.TrimSpace(string(output))
if username == "" || username == "root" {
return fmt.Errorf("no active user session found")
}
log.Infof("starting UI for user: %s", username)
// Get user's UID
userInfo, err := user.Lookup(username)
if err != nil {
return fmt.Errorf("failed to lookup user %s: %w", username, err)
}
// Start the UI process as the console user using launchctl
// This ensures the app runs in the user's context with proper GUI access
launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
log.Infof("launchCmd: %s", launchCmd.String())
// Set the user's home directory for proper macOS app behavior
launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
log.Infof("set HOME environment variable: %s", userInfo.HomeDir)
if err := launchCmd.Start(); err != nil {
return fmt.Errorf("failed to start UI process: %w", err)
}
// Release the process so it can run independently
if err := launchCmd.Process.Release(); err != nil {
log.Warnf("failed to release UI process: %v", err)
}
log.Infof("netbird-ui started successfully for user %s", username)
return nil
}
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
log.Infof("installing pkg file: %s", path)
// Kill any existing UI processes before installation
// This ensures the postinstall script's "open $APP" will start the new version
u.killUI()
volume := "/"
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
if err := cmd.Start(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if err := cmd.Wait(); err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
log.Infof("pkg file installed successfully")
return nil
}
func (u *Installer) updateHomeBrew(ctx context.Context) error {
log.Infof("updating homebrew")
// Kill any existing UI processes before upgrade
// This ensures the new version will be started after upgrade
u.killUI()
// Homebrew must be run as a non-root user
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
// Check both Apple Silicon and Intel Mac paths
brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath := "/opt/homebrew/bin/brew"
if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
// Try Intel Mac path
brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
brewBinPath = "/usr/local/bin/brew"
}
fileInfo, err := os.Stat(brewTapPath)
if err != nil {
return fmt.Errorf("error getting homebrew installation path info: %w", err)
}
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
}
// Get username from UID
brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
if err != nil {
return fmt.Errorf("error looking up brew installer user: %w", err)
}
userName := brewUser.Username
// Get user HOME, required for brew to run correctly
// https://github.com/Homebrew/brew/issues/15833
homeDir := brewUser.HomeDir
// Check if netbird-ui is installed (must run as the brew user, not root)
checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
uiInstalled := checkUICmd.Run() == nil
// Homebrew does not support installing specific versions
// Thus it will always update to latest and ignore targetVersion
upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
if uiInstalled {
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
}
cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
cmd.Env = append(os.Environ(), "HOME="+homeDir)
if output, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
}
log.Infof("homebrew updated successfully")
return nil
}
func (u *Installer) killUI() {
log.Infof("killing existing netbird-ui processes")
cmd := exec.Command("pkill", "-x", "netbird-ui")
if output, err := cmd.CombinedOutput(); err != nil {
// pkill returns exit code 1 if no processes matched, which is fine
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
} else {
log.Infof("netbird-ui processes killed")
}
}
func urlWithVersionArch(_ Type, version string) string {
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -0,0 +1,213 @@
package installer
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
daemonName = "netbird.exe"
uiName = "netbird-ui.exe"
updaterBinary = "updater.exe"
msiLogFile = "msi.log"
msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
var (
defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")
// for the cleanup
binaryExtensions = []string{"msi", "exe"}
)
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
// This will be run by the updater process
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
resultHandler := NewResultHandler(u.tempDir)
// Always ensure daemon and UI are restarted after setup
defer func() {
log.Infof("starting daemon back")
if err := u.startDaemon(daemonFolder); err != nil {
log.Errorf("failed to start daemon: %v", err)
}
log.Infof("starting UI back")
if err := u.startUIAsUser(daemonFolder); err != nil {
log.Errorf("failed to start UI: %v", err)
}
log.Infof("write out result")
var err error
if resultErr == nil {
err = resultHandler.WriteSuccess()
} else {
err = resultHandler.WriteErr(resultErr)
}
if err != nil {
log.Errorf("failed to write update result: %v", err)
}
}()
if dryRun {
log.Infof("dry-run mode enabled, skipping actual installation")
resultErr = fmt.Errorf("dry-run mode enabled")
return
}
installerType, err := typeByFileExtension(installerFile)
if err != nil {
log.Debugf("%v", err)
resultErr = err
return
}
var cmd *exec.Cmd
switch installerType {
case TypeExe:
log.Infof("run exe installer: %s", installerFile)
cmd = exec.CommandContext(ctx, installerFile, "/S")
default:
installerDir := filepath.Dir(installerFile)
logPath := filepath.Join(installerDir, msiLogFile)
log.Infof("run msi installer: %s", installerFile)
cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
}
cmd.Dir = filepath.Dir(installerFile)
if resultErr = cmd.Start(); resultErr != nil {
log.Errorf("error starting installer: %v", resultErr)
return
}
log.Infof("installer started with PID %d", cmd.Process.Pid)
if resultErr = cmd.Wait(); resultErr != nil {
log.Errorf("installer process finished with error: %v", resultErr)
return
}
return nil
}
func (u *Installer) startDaemon(daemonFolder string) error {
log.Infof("starting netbird service")
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
if output, err := cmd.CombinedOutput(); err != nil {
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
return err
}
log.Infof("netbird service started successfully")
return nil
}
func (u *Installer) startUIAsUser(daemonFolder string) error {
uiPath := filepath.Join(daemonFolder, uiName)
log.Infof("starting netbird-ui: %s", uiPath)
// Get the active console session ID
sessionID := windows.WTSGetActiveConsoleSessionId()
if sessionID == 0xFFFFFFFF {
return fmt.Errorf("no active user session found")
}
// Get the user token for that session
var userToken windows.Token
err := windows.WTSQueryUserToken(sessionID, &userToken)
if err != nil {
return fmt.Errorf("failed to query user token: %w", err)
}
defer func() {
if err := userToken.Close(); err != nil {
log.Warnf("failed to close user token: %v", err)
}
}()
// Duplicate the token to a primary token
var primaryToken windows.Token
err = windows.DuplicateTokenEx(
userToken,
windows.MAXIMUM_ALLOWED,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return fmt.Errorf("failed to duplicate token: %w", err)
}
defer func() {
if err := primaryToken.Close(); err != nil {
log.Warnf("failed to close token: %v", err)
}
}()
// Prepare startup info
var si windows.StartupInfo
si.Cb = uint32(unsafe.Sizeof(si))
si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")
var pi windows.ProcessInformation
cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
if err != nil {
return fmt.Errorf("failed to convert path to UTF16: %w", err)
}
creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT
err = windows.CreateProcessAsUser(
primaryToken,
nil,
cmdLine,
nil,
nil,
false,
creationFlags,
nil,
nil,
&si,
&pi,
)
if err != nil {
return fmt.Errorf("CreateProcessAsUser failed: %w", err)
}
// Close handles
if err := windows.CloseHandle(pi.Process); err != nil {
log.Warnf("failed to close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Warnf("failed to close thread handle: %v", err)
}
log.Infof("netbird-ui started successfully in session %d", sessionID)
return nil
}
func urlWithVersionArch(it Type, version string) string {
var url string
if it == TypeExe {
url = exeDownloadURL
} else {
url = msiDownloadURL
}
url = strings.ReplaceAll(url, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -0,0 +1,5 @@
package installer
const (
LogFile = "installer.log"
)

View File

@@ -0,0 +1,15 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run in a new session,
// making it independent of the parent daemon process. This ensures the updater
// survives when the daemon is stopped during the pkg installation.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
}
}

View File

@@ -0,0 +1,14 @@
package installer
import (
"os/exec"
"syscall"
)
// setUpdaterProcAttr configures the updater process to run detached from the parent,
// making it independent of the parent daemon process.
func setUpdaterProcAttr(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
}
}

View File

@@ -0,0 +1,7 @@
//go:build devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
)

View File

@@ -0,0 +1,7 @@
//go:build !devartifactsign
package installer
const (
DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
)

View File

@@ -0,0 +1,230 @@
package installer
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
resultFile = "result.json"
)
type Result struct {
Success bool
Error string
ExecutedAt time.Time
}
// ResultHandler handles reading and writing update results
type ResultHandler struct {
resultFile string
}
// NewResultHandler creates a new communicator with the given directory path
// The result file will be created as "result.json" in the specified directory
func NewResultHandler(installerDir string) *ResultHandler {
// Create it if it doesn't exist
// do not care if already exists
_ = os.MkdirAll(installerDir, 0o700)
rh := &ResultHandler{
resultFile: filepath.Join(installerDir, resultFile),
}
return rh
}
func (rh *ResultHandler) GetErrorResultReason() string {
result, err := rh.tryReadResult()
if err == nil && !result.Success {
return result.Error
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return ""
}
func (rh *ResultHandler) WriteSuccess() error {
result := Result{
Success: true,
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) WriteErr(errReason error) error {
result := Result{
Success: false,
Error: errReason.Error(),
ExecutedAt: time.Now(),
}
return rh.write(result)
}
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
log.Infof("start watching result: %s", rh.resultFile)
// Check if file already exists (updater finished before we started watching)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
dir := filepath.Dir(rh.resultFile)
if err := rh.waitForDirectory(ctx, dir); err != nil {
return Result{}, err
}
return rh.watchForResultFile(ctx, dir)
}
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
ticker := time.NewTicker(300 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if info, err := os.Stat(dir); err == nil && info.IsDir() {
return nil
}
}
}
}
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Error(err)
return Result{}, err
}
defer func() {
if err := watcher.Close(); err != nil {
log.Warnf("failed to close watcher: %v", err)
}
}()
if err := watcher.Add(dir); err != nil {
return Result{}, fmt.Errorf("failed to watch directory: %v", err)
}
// Check again after setting up watcher to avoid race condition
// (file could have been created between initial check and watcher setup)
if result, err := rh.tryReadResult(); err == nil {
log.Infof("installer result: %v", result)
return result, nil
}
for {
select {
case <-ctx.Done():
return Result{}, ctx.Err()
case event, ok := <-watcher.Events:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
if result, done := rh.handleWatchEvent(event); done {
return result, nil
}
case err, ok := <-watcher.Errors:
if !ok {
return Result{}, errors.New("watcher closed unexpectedly")
}
return Result{}, fmt.Errorf("watcher error: %w", err)
}
}
}
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
if event.Name != rh.resultFile {
return Result{}, false
}
if event.Has(fsnotify.Create) {
result, err := rh.tryReadResult()
if err != nil {
log.Debugf("error while reading result: %v", err)
return result, true
}
log.Infof("installer result: %v", result)
return result, true
}
return Result{}, false
}
// Write writes the update result to a file for the UI to read
func (rh *ResultHandler) write(result Result) error {
log.Infof("write out installer result to: %s", rh.resultFile)
// Ensure directory exists
dir := filepath.Dir(rh.resultFile)
if err := os.MkdirAll(dir, 0o755); err != nil {
log.Errorf("failed to create directory %s: %v", dir, err)
return err
}
data, err := json.Marshal(result)
if err != nil {
return err
}
// Write to a temporary file first, then rename for atomic operation
tmpPath := rh.resultFile + ".tmp"
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
log.Errorf("failed to create temp file: %s", err)
return err
}
// Atomic rename
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
log.Warnf("Failed to remove temp result file: %v", err)
}
return err
}
return nil
}
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil && !os.IsNotExist(err) {
return err
}
log.Debugf("delete installer result file: %s", rh.resultFile)
return nil
}
// tryReadResult attempts to read and validate the result file
func (rh *ResultHandler) tryReadResult() (Result, error) {
data, err := os.ReadFile(rh.resultFile)
if err != nil {
return Result{}, err
}
var result Result
if err := json.Unmarshal(data, &result); err != nil {
return Result{}, fmt.Errorf("invalid result format: %w", err)
}
if err := rh.cleanup(); err != nil {
log.Warnf("failed to cleanup result file: %v", err)
}
return result, nil
}

View File

@@ -0,0 +1,14 @@
package installer
type Type struct {
name string
downloadable bool
}
func (t Type) String() string {
return t.name
}
func (t Type) Downloadable() bool {
return t.downloadable
}

View File

@@ -0,0 +1,22 @@
package installer
import (
"context"
"os/exec"
)
var (
TypeHomebrew = Type{name: "Homebrew", downloadable: false}
TypePKG = Type{name: "pkg", downloadable: true}
)
func TypeOfInstaller(ctx context.Context) Type {
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
_, err := cmd.Output()
if err != nil && cmd.ProcessState.ExitCode() == 1 {
// Not installed using pkg file, thus installed using Homebrew
return TypeHomebrew
}
return TypePKG
}

View File

@@ -0,0 +1,51 @@
package installer
import (
"context"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
const (
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
)
var (
TypeExe = Type{name: "EXE", downloadable: true}
TypeMSI = Type{name: "MSI", downloadable: true}
)
func TypeOfInstaller(_ context.Context) Type {
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
for _, path := range paths {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
continue
}
if err := k.Close(); err != nil {
log.Warnf("Error closing registry key: %v", err)
}
return TypeExe
}
log.Debug("No registry entry found for Netbird, assuming MSI installation")
return TypeMSI
}
func typeByFileExtension(filePath string) (Type, error) {
switch {
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
return TypeExe, nil
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
return TypeMSI, nil
default:
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
}
}

View File

@@ -0,0 +1,374 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
var errNoUpdateState = errors.New("no update state found")
type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}
func (u UpdateState) Name() string {
return "autoUpdate"
}
type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
currentVersion string
update UpdateInterface
wg sync.WaitGroup
cancel context.CancelFunc
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protect update and expectedVersion fields
updateMutex sync.Mutex
triggerUpdateFn func(context.Context, string) error
}
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Home Brew installation")
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager, nil
}
// CheckUpdateSuccess checks if the update was successful and send a notification.
// It works without to start the update manager.
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
reason := m.lastResultErrReason()
if reason != "" {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %s", reason),
nil,
)
}
updateState, err := m.loadAndDeleteUpdateState(ctx)
if err != nil {
if errors.Is(err, errNoUpdateState) {
return
}
log.Errorf("failed to load update state: %v", err)
return
}
log.Debugf("auto-update state loaded, %v", *updateState)
if updateState.TargetVersion == m.currentVersion {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
nil,
)
return
}
}
func (m *Manager) Start(ctx context.Context) {
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
m.update.SetDaemonVersion(m.currentVersion)
m.update.SetOnUpdateListener(func() {
select {
case m.updateChannel <- struct{}{}:
default:
}
})
go m.update.StartFetcher()
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
go m.updateLoop(ctx)
}
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if expectedVersion == "" {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
return
}
if expectedVersion == latestVersion {
m.updateToLatestVersion = true
m.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("error parsing version: %v", err)
return
}
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
return
}
m.expectedVersion = expectedSemVer
m.updateToLatestVersion = false
}
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) Stop() {
if m.cancel == nil {
return
}
m.cancel()
m.updateMutex.Lock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
m.updateMutex.Unlock()
m.wg.Wait()
}
func (m *Manager) onContextCancel() {
if m.cancel == nil {
return
}
m.updateMutex.Lock()
defer m.updateMutex.Unlock()
if m.update != nil {
m.update.StopWatch()
m.update = nil
}
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
m.onContextCancel()
return
case <-m.mgmUpdateChan:
case <-m.updateChannel:
log.Infof("fetched new version info")
}
m.handleUpdate(ctx)
}
}
func (m *Manager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion) {
return
}
m.lastTrigger = time.Now()
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": updateVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: updateVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
if err = m.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
}
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
// Returns nil if no state exists.
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
stateType := &UpdateState{}
m.stateManager.RegisterState(stateType)
if err := m.stateManager.LoadState(stateType); err != nil {
return nil, fmt.Errorf("load state: %w", err)
}
state := m.stateManager.GetState(stateType)
if state == nil {
return nil, errNoUpdateState
}
updateState, ok := state.(*UpdateState)
if !ok {
return nil, fmt.Errorf("failed to cast state to UpdateState")
}
if err := m.stateManager.DeleteState(updateState); err != nil {
return nil, fmt.Errorf("delete state: %w", err)
}
if err := m.stateManager.PersistState(ctx); err != nil {
return nil, fmt.Errorf("persist state: %w", err)
}
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil {
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
return false
}
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
return true
}
func (m *Manager) lastResultErrReason() string {
inst := installer.New()
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}

View File

@@ -0,0 +1,214 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

View File

@@ -0,0 +1,39 @@
//go:build !windows && !darwin
package updatemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager is a no-op stub for unsupported platforms
type Manager struct{}
// NewManager returns a no-op manager for unsupported platforms
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
return nil, fmt.Errorf("update manager is not supported on this platform")
}
// CheckUpdateSuccess is a no-op on unsupported platforms
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
// no-op
}
// Start is a no-op on unsupported platforms
func (m *Manager) Start(ctx context.Context) {
// no-op
}
// SetVersion is a no-op on unsupported platforms
func (m *Manager) SetVersion(expectedVersion string) {
// no-op
}
// Stop is a no-op on unsupported platforms
func (m *Manager) Stop() {
// no-op
}

View File

@@ -0,0 +1,302 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"hash"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/blake2s"
)
const (
tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
tagArtifactPublic = "ARTIFACT PUBLIC KEY"
maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
)
// ArtifactHash wraps a hash.Hash and counts bytes written
type ArtifactHash struct {
hash.Hash
}
// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
func NewArtifactHash() *ArtifactHash {
h, err := blake2s.New256(nil)
if err != nil {
panic(err) // Should never happen with nil Key
}
return &ArtifactHash{Hash: h}
}
func (ah *ArtifactHash) Write(b []byte) (int, error) {
return ah.Hash.Write(b)
}
// ArtifactKey is a signing Key used to sign artifacts
type ArtifactKey struct {
PrivateKey
}
func (k ArtifactKey) String() string {
return fmt.Sprintf(
"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
// Verify root key is still valid
if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
}
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
ak := &ArtifactKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(ak.PrivateKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
// Sign the public key with the root key
signature, err := SignArtifactKey(*rootKey, pubPEM)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
}
return ak, privPEM, pubPEM, signature, nil
}
func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
if err != nil {
return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
}
return ArtifactKey{pk}, nil
}
func ParseArtifactPubKey(data []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(data, tagArtifactPublic)
return pk, err
}
func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
if len(keys) == 0 {
return nil, nil, errors.New("no keys to bundle")
}
// Create bundle by concatenating PEM-encoded keys
var pubBundle []byte
for _, pk := range keys {
// Marshal PublicKey struct to JSON
pubJSON, err := json.Marshal(pk)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagArtifactPublic,
Bytes: pubJSON,
})
pubBundle = append(pubBundle, pubPEM...)
}
// Sign the entire bundle with the root key
signature, err := SignArtifactKey(*rootKey, pubBundle)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
}
return pubBundle, signature, nil
}
func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("artifact signature error: %v", err)
return nil, err
}
// Reconstruct the signed message: artifact_key_data || timestamp
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("failed to verify signature of artifact keys")
}
pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
if err != nil {
log.Debugf("failed to parse public keys: %s", err)
return nil, err
}
validKeys := make([]PublicKey, 0, len(pubKeys))
for _, pubKey := range pubKeys {
// Filter out expired keys
if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
log.Debugf("Key %s is expired at %v (current time %v)",
pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
continue
}
if revocationList != nil {
if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
log.Debugf("Key %s is revoked as of %v (created %v)",
pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
continue
}
}
validKeys = append(validKeys, pubKey)
}
if len(validKeys) == 0 {
log.Debugf("no valid public keys found for artifact keys")
return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys))
}
return validKeys, nil
}
func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
// Validate signature timestamp
now := time.Now().UTC()
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("failed to verify signature of artifact: %s", err)
return err
}
if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
return fmt.Errorf("artifact signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return fmt.Errorf("failed to hash artifact: %w", err)
}
hash := h.Sum(nil)
// Reconstruct the signed message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
// Find matching Key and verify
for _, keyInfo := range artifactPubKeys {
if keyInfo.Metadata.ID == signature.KeyID {
// Check Key expiration
if !keyInfo.Metadata.ExpiresAt.IsZero() &&
signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
return fmt.Errorf("signing Key %s expired at %v, signature from %v",
signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
}
if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
return nil
}
return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
}
}
return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
}
func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
if len(data) == 0 { // Check happens too late
return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
}
h := NewArtifactHash()
if _, err := h.Write(data); err != nil {
return nil, fmt.Errorf("failed to write artifact hash: %w", err)
}
timestamp := time.Now().UTC()
if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
}
hash := h.Sum(nil)
// Create message: hash || length || timestamp
msg := make([]byte, 0, len(hash)+8+8)
msg = append(msg, hash...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(artifactKey.Key, msg)
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: artifactKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
return json.Marshal(bundle)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,6 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -0,0 +1,6 @@
-----BEGIN ROOT PUBLIC KEY-----
eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
-----END ROOT PUBLIC KEY-----

View File

@@ -0,0 +1,174 @@
// Package reposign implements a cryptographic signing and verification system
// for NetBird software update artifacts. It provides a hierarchical key
// management system with support for key rotation, revocation, and secure
// artifact distribution.
//
// # Architecture
//
// The package uses a two-tier key hierarchy:
//
// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
// in the client binary and establish the root of trust. Root keys should
// be kept offline and highly secured.
//
// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
// packages, etc.). These are rotated regularly and can be revoked if
// compromised. Artifact keys are signed by root keys and distributed via
// a public repository.
//
// This separation allows for operational flexibility: artifact keys can be
// rotated frequently without requiring client updates, while root keys remain
// stable and embedded in the software.
//
// # Cryptographic Primitives
//
// The package uses strong, modern cryptographic algorithms:
// - Ed25519: Fast, secure digital signatures (no timing attacks)
// - BLAKE2s-256: Fast cryptographic hash for artifacts
// - SHA-256: Key ID generation
// - JSON: Structured key and signature serialization
// - PEM: Standard key encoding format
//
// # Security Features
//
// Timestamp Binding:
// - All signatures include cryptographically-bound timestamps
// - Prevents replay attacks and enforces signature freshness
// - Clock skew tolerance: 5 minutes
//
// Key Expiration:
// - All keys have expiration times
// - Expired keys are automatically rejected
// - Signing with an expired key fails immediately
//
// Key Revocation:
// - Compromised keys can be revoked via a signed revocation list
// - Revocation list is checked during artifact validation
// - Revoked keys are filtered out before artifact verification
//
// # File Structure
//
// The package expects the following file layout in the key repository:
//
// signrepo/
// artifact-key-pub.pem # Bundle of artifact public keys
// artifact-key-pub.pem.sig # Root signature of the bundle
// revocation-list.json # List of revoked key IDs
// revocation-list.json.sig # Root signature of revocation list
//
// And in the artifacts repository:
//
// releases/
// v0.28.0/
// netbird-linux-amd64
// netbird-linux-amd64.sig # Artifact signature
// netbird-darwin-amd64
// netbird-darwin-amd64.sig
// ...
//
// # Embedded Root Keys
//
// Root public keys are embedded in the client binary at compile time:
// - Production keys: certs/ directory
// - Development keys: certsdev/ directory
//
// The build tag determines which keys are embedded:
// - Production builds: //go:build !devartifactsign
// - Development builds: //go:build devartifactsign
//
// This ensures that development artifacts cannot be verified using production
// keys and vice versa.
//
// # Key Rotation Strategies
//
// Root Key Rotation:
//
// Root keys can be rotated without breaking existing clients by leveraging
// the multi-key verification system. The loadEmbeddedPublicKeys function
// reads ALL files from the certs/ directory and accepts signatures from ANY
// of the embedded root keys.
//
// To rotate root keys:
//
// 1. Generate a new root key pair:
// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
//
// 2. Add the new public key to the certs/ directory as a new file:
// certs/
// root-pub-2024.pem # Old key (keep this!)
// root-pub-2025.pem # New key (add this)
//
// 3. Build new client versions with both keys embedded. The verification
// will accept signatures from either key.
//
// 4. Start signing new artifact keys with the new root key. Old clients
// with only the old root key will reject these, but new clients with
// both keys will accept them.
//
// Each file in certs/ can contain a single key or a bundle of keys (multiple
// PEM blocks). The system will parse all keys from all files and use them
// for verification. This provides maximum flexibility for key management.
//
// Important: Never remove all old root keys at once. Always maintain at least
// one overlapping key between releases to ensure smooth transitions.
//
// Artifact Key Rotation:
//
// Artifact keys should be rotated regularly (e.g., every 90 days) using the
// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
// keys to be bundled together in a single signed package, and ValidateArtifact
// will accept signatures from ANY key in the bundle.
//
// To rotate artifact keys smoothly:
//
// 1. Generate a new artifact key while keeping the old one:
// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
// // Keep oldPubPEM and oldKey available
//
// 2. Create a bundle containing both old and new public keys
//
// 3. Upload the bundle and its signature to the key repository:
// signrepo/artifact-key-pub.pem # Contains both keys
// signrepo/artifact-key-pub.pem.sig # Root signature
//
// 4. Start signing new releases with the NEW key, but keep the bundle
// unchanged. Clients will download the bundle (containing both keys)
// and accept signatures from either key.
//
// Key bundle validation workflow:
// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
// 3. ValidateArtifactKeys parses all public keys from the bundle
// 4. ValidateArtifactKeys filters out expired or revoked keys
// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
//
// This multi-key acceptance model enables overlapping validity periods and
// smooth transitions without client update requirements.
//
// # Best Practices
//
// Root Key Management:
// - Generate root keys offline on an air-gapped machine
// - Store root private keys in hardware security modules (HSM) if possible
// - Use separate root keys for production and development
// - Rotate root keys infrequently (e.g., every 5-10 years)
// - Plan for root key rotation: embed multiple root public keys
//
// Artifact Key Management:
// - Rotate artifact keys regularly (e.g., every 90 days)
// - Use separate artifact keys for different release channels if needed
// - Revoke keys immediately upon suspected compromise
// - Bundle multiple artifact keys to enable smooth rotation
//
// Signing Process:
// - Sign artifacts in a secure CI/CD environment
// - Never commit private keys to version control
// - Use environment variables or secret management for keys
// - Verify signatures immediately after signing
//
// Distribution:
// - Serve keys and revocation lists from a reliable CDN
// - Use HTTPS for all key and artifact downloads
// - Monitor download failures and signature verification failures
// - Keep revocation list up to date
package reposign

View File

@@ -0,0 +1,10 @@
//go:build devartifactsign
package reposign
import "embed"
//go:embed certsdev
var embeddedCerts embed.FS
const embeddedCertsDir = "certsdev"

View File

@@ -0,0 +1,10 @@
//go:build !devartifactsign
package reposign
import "embed"
//go:embed certs
var embeddedCerts embed.FS
const embeddedCertsDir = "certs"

View File

@@ -0,0 +1,171 @@
package reposign
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"time"
)
const (
maxClockSkew = 5 * time.Minute
)
// KeyID is a unique identifier for a Key (first 8 bytes of SHA-256 of public Key)
type KeyID [8]byte
// computeKeyID generates a unique ID from a public Key
func computeKeyID(pub ed25519.PublicKey) KeyID {
h := sha256.Sum256(pub)
var id KeyID
copy(id[:], h[:8])
return id
}
// MarshalJSON implements json.Marshaler for KeyID
func (k KeyID) MarshalJSON() ([]byte, error) {
return json.Marshal(k.String())
}
// UnmarshalJSON implements json.Unmarshaler for KeyID
func (k *KeyID) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
parsed, err := ParseKeyID(s)
if err != nil {
return err
}
*k = parsed
return nil
}
// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
func ParseKeyID(s string) (KeyID, error) {
var id KeyID
if len(s) != 16 {
return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
}
b, err := hex.DecodeString(s)
if err != nil {
return id, fmt.Errorf("failed to decode KeyID: %w", err)
}
copy(id[:], b)
return id, nil
}
func (k KeyID) String() string {
return fmt.Sprintf("%x", k[:])
}
// KeyMetadata contains versioning and lifecycle information for a Key
type KeyMetadata struct {
ID KeyID `json:"id"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration
}
// PublicKey wraps a public Key with its Metadata
type PublicKey struct {
Key ed25519.PublicKey
Metadata KeyMetadata
}
func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
var keys []PublicKey
for len(bundle) > 0 {
keyInfo, rest, err := parsePublicKey(bundle, typeTag)
if err != nil {
return nil, err
}
keys = append(keys, keyInfo)
bundle = rest
}
if len(keys) == 0 {
return nil, errors.New("no keys found in bundle")
}
return keys, nil
}
func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
b, rest := pem.Decode(data)
if b == nil {
return PublicKey{}, nil, errors.New("failed to decode PEM data")
}
if b.Type != typeTag {
return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pub PublicKey
if err := json.Unmarshal(b.Bytes, &pub); err != nil {
return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
}
// Validate key length
if len(pub.Key) != ed25519.PublicKeySize {
return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
ed25519.PublicKeySize, len(pub.Key))
}
// Always recompute ID to ensure integrity
pub.Metadata.ID = computeKeyID(pub.Key)
return pub, rest, nil
}
type PrivateKey struct {
Key ed25519.PrivateKey
Metadata KeyMetadata
}
func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
b, rest := pem.Decode(data)
if b == nil {
return PrivateKey{}, errors.New("failed to decode PEM data")
}
if len(rest) > 0 {
return PrivateKey{}, errors.New("trailing PEM data")
}
if b.Type != typeTag {
return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
}
// Unmarshal JSON-embedded format
var pk PrivateKey
if err := json.Unmarshal(b.Bytes, &pk); err != nil {
return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
}
// Validate key length
if len(pk.Key) != ed25519.PrivateKeySize {
return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
ed25519.PrivateKeySize, len(pk.Key))
}
return pk, nil
}
func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
// Verify with root keys
var rootKeys []ed25519.PublicKey
for _, r := range publicRootKeys {
rootKeys = append(rootKeys, r.Key)
}
for _, k := range rootKeys {
if ed25519.Verify(k, msg, sig) {
return true
}
}
return false
}

View File

@@ -0,0 +1,636 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test KeyID functions
func TestComputeKeyID(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
// Verify it's the first 8 bytes of SHA-256
h := sha256.Sum256(pub)
expectedID := KeyID{}
copy(expectedID[:], h[:8])
assert.Equal(t, expectedID, keyID)
}
func TestComputeKeyID_Deterministic(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Computing KeyID multiple times should give the same result
keyID1 := computeKeyID(pub)
keyID2 := computeKeyID(pub)
assert.Equal(t, keyID1, keyID2)
}
func TestComputeKeyID_DifferentKeys(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
// Different keys should produce different IDs
assert.NotEqual(t, keyID1, keyID2)
}
func TestParseKeyID_Valid(t *testing.T) {
hexStr := "0123456789abcdef"
keyID, err := ParseKeyID(hexStr)
require.NoError(t, err)
expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
assert.Equal(t, expected, keyID)
}
func TestParseKeyID_InvalidLength(t *testing.T) {
tests := []struct {
name string
input string
}{
{"too short", "01234567"},
{"too long", "0123456789abcdef00"},
{"empty", ""},
{"odd length", "0123456789abcde"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseKeyID(tt.input)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid KeyID length")
})
}
}
func TestParseKeyID_InvalidHex(t *testing.T) {
invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex
_, err := ParseKeyID(invalidHex)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode KeyID")
}
func TestKeyID_String(t *testing.T) {
keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
str := keyID.String()
assert.Equal(t, "0123456789abcdef", str)
}
func TestKeyID_RoundTrip(t *testing.T) {
original := "fedcba9876543210"
keyID, err := ParseKeyID(original)
require.NoError(t, err)
result := keyID.String()
assert.Equal(t, original, result)
}
func TestKeyID_ZeroValue(t *testing.T) {
keyID := KeyID{}
str := keyID.String()
assert.Equal(t, "0000000000000000", str)
}
// Test KeyMetadata
func TestKeyMetadata_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, metadata.ID, decoded.ID)
assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
}
func TestKeyMetadata_NoExpiration(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
ExpiresAt: time.Time{}, // Zero value = no expiration
}
jsonData, err := json.Marshal(metadata)
require.NoError(t, err)
var decoded KeyMetadata
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.True(t, decoded.ExpiresAt.IsZero())
}
// Test PublicKey
func TestPublicKey_JSONMarshaling(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
var decoded PublicKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, pubKey.Key, decoded.Key)
assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePublicKey
func TestParsePublicKey_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
}
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
// Marshal to JSON
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode to PEM
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse it back
parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
assert.Empty(t, rest)
assert.Equal(t, pub, parsed.Key)
assert.Equal(t, metadata.ID, parsed.Metadata.ID)
}
func TestParsePublicKey_InvalidPEM(t *testing.T) {
invalidPEM := []byte("not a PEM")
_, _, err := parsePublicKey(invalidPEM, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePublicKey_WrongType(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
// Encode with wrong type
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePublicKey_InvalidJSON(t *testing.T) {
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: []byte("invalid json"),
})
_, _, err := parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParsePublicKey_InvalidKeySize(t *testing.T) {
// Create a public key with wrong size
pubKey := PublicKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
_, _, err = parsePublicKey(pemData, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
}
func TestParsePublicKey_IDRecomputation(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
// Create a public key with WRONG ID
wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: wrongID,
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
// Parse should recompute the correct ID
parsed, _, err := parsePublicKey(pemData, tagRootPublic)
require.NoError(t, err)
correctID := computeKeyID(pub)
assert.Equal(t, correctID, parsed.Metadata.ID)
assert.NotEqual(t, wrongID, parsed.Metadata.ID)
}
// Test parsePublicKeyBundle
func TestParsePublicKeyBundle_Single(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 1)
assert.Equal(t, pub, keys[0].Key)
}
func TestParsePublicKeyBundle_Multiple(t *testing.T) {
var bundle []byte
// Create 3 keys
for i := 0; i < 3; i++ {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := PublicKey{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(pubKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: jsonData,
})
bundle = append(bundle, pemData...)
}
keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
require.NoError(t, err)
assert.Len(t, keys, 3)
}
func TestParsePublicKeyBundle_Empty(t *testing.T) {
_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no keys found")
}
func TestParsePublicKeyBundle_Invalid(t *testing.T) {
_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
assert.Error(t, err)
}
// Test PrivateKey
func TestPrivateKey_JSONMarshaling(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
var decoded PrivateKey
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Equal(t, privKey.Key, decoded.Key)
assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
}
// Test parsePrivateKey
func TestParsePrivateKey_Valid(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
parsed, err := parsePrivateKey(pemData, tagRootPrivate)
require.NoError(t, err)
assert.Equal(t, priv, parsed.Key)
}
func TestParsePrivateKey_InvalidPEM(t *testing.T) {
_, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to decode PEM")
}
func TestParsePrivateKey_TrailingData(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
// Add trailing data
pemData = append(pemData, []byte("extra data")...)
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "trailing PEM data")
}
func TestParsePrivateKey_WrongType(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
privKey := PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: "WRONG TYPE",
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
}
func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
privKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
jsonData, err := json.Marshal(privKey)
require.NoError(t, err)
pemData := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: jsonData,
})
_, err = parsePrivateKey(pemData, tagRootPrivate)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
// Test verifyAny
func TestVerifyAny_ValidSignature(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_InvalidSignature(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
invalidSignature := make([]byte, ed25519.SignatureSize)
rootKeys := []PublicKey{
{
Key: pub,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: time.Now().UTC(),
},
},
}
result := verifyAny(rootKeys, message, invalidSignature)
assert.False(t, result)
}
func TestVerifyAny_MultipleKeys(t *testing.T) {
// Create 3 key pairs
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub3, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
{Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
{Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
}
result := verifyAny(rootKeys, message, signature)
assert.True(t, result)
}
func TestVerifyAny_NoMatchingKey(t *testing.T) {
_, priv1, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv1, message)
// Only include pub2, not pub1
rootKeys := []PublicKey{
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
}
result := verifyAny(rootKeys, message, signature)
assert.False(t, result)
}
func TestVerifyAny_EmptyKeys(t *testing.T) {
message := []byte("test message")
signature := make([]byte, ed25519.SignatureSize)
result := verifyAny([]PublicKey{}, message, signature)
assert.False(t, result)
}
func TestVerifyAny_TamperedMessage(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
message := []byte("test message")
signature := ed25519.Sign(priv, message)
rootKeys := []PublicKey{
{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
}
// Verify with different message
tamperedMessage := []byte("different message")
result := verifyAny(rootKeys, tamperedMessage, signature)
assert.False(t, result)
}

View File

@@ -0,0 +1,229 @@
package reposign
import (
"crypto/ed25519"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour
defaultRevocationListExpiration = 365 * 24 * time.Hour
)
type RevocationList struct {
Revoked map[KeyID]time.Time `json:"revoked"` // KeyID -> revocation time
LastUpdated time.Time `json:"last_updated"` // When the list was last modified
ExpiresAt time.Time `json:"expires_at"` // When the list expires
}
func (rl RevocationList) MarshalJSON() ([]byte, error) {
// Convert map[KeyID]time.Time to map[string]time.Time
strMap := make(map[string]time.Time, len(rl.Revoked))
for k, v := range rl.Revoked {
strMap[k.String()] = v
}
return json.Marshal(map[string]interface{}{
"revoked": strMap,
"last_updated": rl.LastUpdated,
"expires_at": rl.ExpiresAt,
})
}
func (rl *RevocationList) UnmarshalJSON(data []byte) error {
var temp struct {
Revoked map[string]time.Time `json:"revoked"`
LastUpdated time.Time `json:"last_updated"`
ExpiresAt time.Time `json:"expires_at"`
Version int `json:"version"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Convert map[string]time.Time back to map[KeyID]time.Time
rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked))
for k, v := range temp.Revoked {
kid, err := ParseKeyID(k)
if err != nil {
return fmt.Errorf("failed to parse KeyID %q: %w", k, err)
}
rl.Revoked[kid] = v
}
rl.LastUpdated = temp.LastUpdated
rl.ExpiresAt = temp.ExpiresAt
return nil
}
func ParseRevocationList(data []byte) (*RevocationList, error) {
var rl RevocationList
if err := json.Unmarshal(data, &rl); err != nil {
return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err)
}
// Initialize the map if it's nil (in case of empty JSON object)
if rl.Revoked == nil {
rl.Revoked = make(map[KeyID]time.Time)
}
if rl.LastUpdated.IsZero() {
return nil, fmt.Errorf("revocation list missing last_updated timestamp")
}
if rl.ExpiresAt.IsZero() {
return nil, fmt.Errorf("revocation list missing expires_at timestamp")
}
return &rl, nil
}
func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) {
revoList, err := ParseRevocationList(data)
if err != nil {
log.Debugf("failed to parse revocation list: %s", err)
return nil, err
}
now := time.Now().UTC()
// Validate signature timestamp
if signature.Timestamp.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp)
log.Debugf("revocation list signature error: %v", err)
return nil, err
}
if now.Sub(signature.Timestamp) > maxRevocationSignatureAge {
err := fmt.Errorf("revocation list signature is too old: %v (created %v)",
now.Sub(signature.Timestamp), signature.Timestamp)
log.Debugf("revocation list signature error: %v", err)
return nil, err
}
// Ensure LastUpdated is not in the future (with clock skew tolerance)
if revoList.LastUpdated.After(now.Add(maxClockSkew)) {
err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated)
log.Errorf("rejecting future-dated revocation list: %v", err)
return nil, err
}
// Check if the revocation list has expired
if now.After(revoList.ExpiresAt) {
err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now)
log.Errorf("rejecting expired revocation list: %v", err)
return nil, err
}
// Ensure ExpiresAt is not in the future by more than the expected expiration window
// (allows some clock skew but prevents maliciously long expiration times)
if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) {
err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt)
log.Errorf("rejecting revocation list with invalid expiration: %v", err)
return nil, err
}
// Validate signature timestamp is close to LastUpdated
// (prevents signing old lists with new timestamps)
timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs()
if timeDiff > maxClockSkew {
err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)",
signature.Timestamp, revoList.LastUpdated, timeDiff)
log.Errorf("timestamp mismatch in revocation list: %v", err)
return nil, err
}
// Reconstruct the signed message: revocation_list_data || timestamp || version
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
if !verifyAny(publicRootKeys, msg, signature.Signature) {
return nil, errors.New("revocation list verification failed")
}
return revoList, nil
}
func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) {
now := time.Now()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now.UTC(),
ExpiresAt: now.Add(expiration).UTC(),
}
signature, err := signRevocationList(privateRootKey, rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
}
rlData, err := json.Marshal(&rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
}
signData, err := json.Marshal(signature)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
}
return rlData, signData, nil
}
func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) {
now := time.Now().UTC()
rl.Revoked[kid] = now
rl.LastUpdated = now
rl.ExpiresAt = now.Add(expiration)
signature, err := signRevocationList(privateRootKey, rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
}
rlData, err := json.Marshal(&rl)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
}
signData, err := json.Marshal(signature)
if err != nil {
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
}
return rlData, signData, nil
}
func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) {
data, err := json.Marshal(rl)
if err != nil {
return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err)
}
timestamp := time.Now().UTC()
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(privateRootKey.Key, msg)
signature := &Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: privateRootKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
return signature, nil
}

View File

@@ -0,0 +1,860 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test RevocationList marshaling/unmarshaling
func TestRevocationList_MarshalJSON(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC)
rl := &RevocationList{
Revoked: map[KeyID]time.Time{
keyID: revokedTime,
},
LastUpdated: lastUpdated,
ExpiresAt: expiresAt,
}
jsonData, err := json.Marshal(rl)
require.NoError(t, err)
// Verify it can be unmarshaled back
var decoded map[string]interface{}
err = json.Unmarshal(jsonData, &decoded)
require.NoError(t, err)
assert.Contains(t, decoded, "revoked")
assert.Contains(t, decoded, "last_updated")
assert.Contains(t, decoded, "expires_at")
}
func TestRevocationList_UnmarshalJSON(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
jsonData := map[string]interface{}{
"revoked": map[string]string{
keyID.String(): revokedTime.Format(time.RFC3339),
},
"last_updated": lastUpdated.Format(time.RFC3339),
}
jsonBytes, err := json.Marshal(jsonData)
require.NoError(t, err)
var rl RevocationList
err = json.Unmarshal(jsonBytes, &rl)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
assert.Contains(t, rl.Revoked, keyID)
assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix())
}
func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) {
pub1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pub2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID1 := computeKeyID(pub1)
keyID2 := computeKeyID(pub2)
original := &RevocationList{
Revoked: map[KeyID]time.Time{
keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC),
},
LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC),
}
// Marshal
jsonData, err := original.MarshalJSON()
require.NoError(t, err)
// Unmarshal
var decoded RevocationList
err = decoded.UnmarshalJSON(jsonData)
require.NoError(t, err)
// Verify
assert.Len(t, decoded.Revoked, 2)
assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix())
assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix())
assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix())
}
func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) {
jsonData := []byte(`{
"revoked": {
"invalid_key_id": "2024-01-15T10:30:00Z"
},
"last_updated": "2024-01-15T11:00:00Z"
}`)
var rl RevocationList
err := json.Unmarshal(jsonData, &rl)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse KeyID")
}
func TestRevocationList_EmptyRevoked(t *testing.T) {
rl := &RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC(),
}
jsonData, err := rl.MarshalJSON()
require.NoError(t, err)
var decoded RevocationList
err = decoded.UnmarshalJSON(jsonData)
require.NoError(t, err)
assert.Empty(t, decoded.Revoked)
assert.NotNil(t, decoded.Revoked)
}
// Test ParseRevocationList
func TestParseRevocationList_Valid(t *testing.T) {
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(pub)
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
rl := RevocationList{
Revoked: map[KeyID]time.Time{
keyID: revokedTime,
},
LastUpdated: lastUpdated,
ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC),
}
jsonData, err := rl.MarshalJSON()
require.NoError(t, err)
parsed, err := ParseRevocationList(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed)
assert.Len(t, parsed.Revoked, 1)
assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix())
}
func TestParseRevocationList_InvalidJSON(t *testing.T) {
invalidJSON := []byte("not valid json")
_, err := ParseRevocationList(invalidJSON)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to unmarshal")
}
func TestParseRevocationList_MissingLastUpdated(t *testing.T) {
jsonData := []byte(`{
"revoked": {}
}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing last_updated")
}
func TestParseRevocationList_EmptyObject(t *testing.T) {
jsonData := []byte(`{}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing last_updated")
}
func TestParseRevocationList_NilRevoked(t *testing.T) {
lastUpdated := time.Now().UTC()
expiresAt := lastUpdated.Add(90 * 24 * time.Hour)
jsonData := []byte(`{
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `",
"expires_at": "` + expiresAt.Format(time.RFC3339) + `"
}`)
parsed, err := ParseRevocationList(jsonData)
require.NoError(t, err)
assert.NotNil(t, parsed.Revoked)
assert.Empty(t, parsed.Revoked)
}
func TestParseRevocationList_MissingExpiresAt(t *testing.T) {
lastUpdated := time.Now().UTC()
jsonData := []byte(`{
"revoked": {},
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `"
}`)
_, err := ParseRevocationList(jsonData)
assert.Error(t, err)
assert.Contains(t, err.Error(), "missing expires_at")
}
// Test ValidateRevocationList
func TestValidateRevocationList_Valid(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Validate
rl, err := ValidateRevocationList(rootKeys, rlData, *signature)
require.NoError(t, err)
assert.NotNil(t, rl)
assert.Empty(t, rl.Revoked)
}
func TestValidateRevocationList_InvalidSignature(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Create invalid signature
invalidSig := Signature{
Signature: make([]byte, 64),
Timestamp: time.Now().UTC(),
KeyID: computeKeyID(rootPub),
Algorithm: "ed25519",
HashAlgo: "sha512",
}
// Validate should fail
_, err = ValidateRevocationList(rootKeys, rlData, invalidSig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "verification failed")
}
func TestValidateRevocationList_FutureTimestamp(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Modify timestamp to be in the future
signature.Timestamp = time.Now().UTC().Add(10 * time.Minute)
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
assert.Error(t, err)
assert.Contains(t, err.Error(), "in the future")
}
func TestValidateRevocationList_TooOld(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
signature, err := ParseSignature(sigData)
require.NoError(t, err)
// Modify timestamp to be too old
signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour)
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too old")
}
func TestValidateRevocationList_InvalidJSON(t *testing.T) {
rootPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
signature := Signature{
Signature: make([]byte, 64),
Timestamp: time.Now().UTC(),
KeyID: computeKeyID(rootPub),
Algorithm: "ed25519",
HashAlgo: "sha512",
}
_, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature)
assert.Error(t, err)
}
func TestValidateRevocationList_FutureLastUpdated(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with future LastUpdated
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC().Add(10 * time.Minute),
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "LastUpdated is in the future")
}
func TestValidateRevocationList_TimestampMismatch(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with LastUpdated far in the past
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: time.Now().UTC().Add(-1 * time.Hour),
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it with current timestamp
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
// Modify signature timestamp to differ too much from LastUpdated
sig.Timestamp = time.Now().UTC()
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "differs too much")
}
func TestValidateRevocationList_Expired(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list that expired in the past
now := time.Now().UTC()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now.Add(-100 * 24 * time.Hour),
ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
// Adjust signature timestamp to match LastUpdated
sig.Timestamp = rl.LastUpdated
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}
func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge)
now := time.Now().UTC()
rl := RevocationList{
Revoked: make(map[KeyID]time.Time),
LastUpdated: now,
ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future
}
rlData, err := json.Marshal(rl)
require.NoError(t, err)
// Sign it
sig, err := signRevocationList(rootKey, rl)
require.NoError(t, err)
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too far in the future")
}
// Test CreateRevocationList
func TestCreateRevocationList_Valid(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
assert.NotEmpty(t, rlData)
assert.NotEmpty(t, sigData)
// Verify it can be parsed
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
assert.False(t, rl.LastUpdated.IsZero())
// Verify signature can be parsed
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
// Test ExtendRevocationList
func TestExtendRevocationList_AddKey(t *testing.T) {
// Generate root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
// Generate a key to revoke
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
revokedKeyID := computeKeyID(revokedPub)
// Extend the revocation list
newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
require.NoError(t, err)
// Verify the new list
newRL, err := ParseRevocationList(newRLData)
require.NoError(t, err)
assert.Len(t, newRL.Revoked, 1)
assert.Contains(t, newRL.Revoked, revokedKeyID)
// Verify signature
sig, err := ParseSignature(newSigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestExtendRevocationList_MultipleKeys(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
// Add first key
key1Pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
key1ID := computeKeyID(key1Pub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
// Add second key
key2Pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
key2ID := computeKeyID(key2Pub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 2)
assert.Contains(t, rl.Revoked, key1ID)
assert.Contains(t, rl.Revoked, key2ID)
}
func TestExtendRevocationList_DuplicateKey(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create empty revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
// Add a key
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(keyPub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
firstRevocationTime := rl.Revoked[keyID]
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Add the same key again
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
// The revocation time should be updated
secondRevocationTime := rl.Revoked[keyID]
assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime))
}
func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) {
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Create revocation list
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err := ParseRevocationList(rlData)
require.NoError(t, err)
firstLastUpdated := rl.LastUpdated
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Extend list
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := computeKeyID(keyPub)
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
require.NoError(t, err)
rl, err = ParseRevocationList(rlData)
require.NoError(t, err)
// LastUpdated should be updated
assert.True(t, rl.LastUpdated.After(firstLastUpdated))
}
// Integration test
func TestRevocationList_FullWorkflow(t *testing.T) {
// Create root key
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
rootKey := RootKey{
PrivateKey{
Key: rootPriv,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
rootKeys := []PublicKey{
{
Key: rootPub,
Metadata: KeyMetadata{
ID: computeKeyID(rootPub),
CreatedAt: time.Now().UTC(),
},
},
}
// Step 1: Create empty revocation list
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 2: Validate it
sig, err := ParseSignature(sigData)
require.NoError(t, err)
rl, err := ValidateRevocationList(rootKeys, rlData, *sig)
require.NoError(t, err)
assert.Empty(t, rl.Revoked)
// Step 3: Revoke a key
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
revokedKeyID := computeKeyID(revokedPub)
rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 4: Validate the extended list
sig, err = ParseSignature(sigData)
require.NoError(t, err)
rl, err = ValidateRevocationList(rootKeys, rlData, *sig)
require.NoError(t, err)
assert.Len(t, rl.Revoked, 1)
assert.Contains(t, rl.Revoked, revokedKeyID)
// Step 5: Verify the revocation time is reasonable
revTime := rl.Revoked[revokedKeyID]
now := time.Now().UTC()
assert.True(t, revTime.Before(now) || revTime.Equal(now))
assert.True(t, now.Sub(revTime) < time.Minute)
}

View File

@@ -0,0 +1,120 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"fmt"
"time"
)
const (
tagRootPrivate = "ROOT PRIVATE KEY"
tagRootPublic = "ROOT PUBLIC KEY"
)
// RootKey is a root Key used to sign signing keys
type RootKey struct {
PrivateKey
}
func (k RootKey) String() string {
return fmt.Sprintf(
"RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
k.Metadata.ID,
k.Metadata.CreatedAt.Format(time.RFC3339),
k.Metadata.ExpiresAt.Format(time.RFC3339),
)
}
func ParseRootKey(privKeyPEM []byte) (*RootKey, error) {
pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate)
if err != nil {
return nil, fmt.Errorf("failed to parse root Key: %w", err)
}
return &RootKey{pk}, nil
}
// ParseRootPublicKey parses a root public key from PEM format
func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) {
pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic)
if err != nil {
return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err)
}
return pk, nil
}
// GenerateRootKey generates a new root Key pair with Metadata
func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) {
now := time.Now()
expirationTime := now.Add(expiration)
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, nil, err
}
metadata := KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: now.UTC(),
ExpiresAt: expirationTime.UTC(),
}
rk := &RootKey{
PrivateKey{
Key: priv,
Metadata: metadata,
},
}
// Marshal PrivateKey struct to JSON
privJSON, err := json.Marshal(rk.PrivateKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
}
// Marshal PublicKey struct to JSON
pubKey := PublicKey{
Key: pub,
Metadata: metadata,
}
pubJSON, err := json.Marshal(pubKey)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
}
// Encode to PEM with metadata embedded in bytes
privPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPublic,
Bytes: pubJSON,
})
return rk, privPEM, pubPEM, nil
}
func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) {
timestamp := time.Now().UTC()
// This ensures the timestamp is cryptographically bound to the signature
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
sig := ed25519.Sign(rootKey.Key, msg)
// Create signature bundle with timestamp and Metadata
bundle := Signature{
Signature: sig,
Timestamp: timestamp,
KeyID: rootKey.Metadata.ID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
return json.Marshal(bundle)
}

View File

@@ -0,0 +1,476 @@
package reposign
import (
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
"encoding/json"
"encoding/pem"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test RootKey.String()
func TestRootKey_String(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC)
rk := RootKey{
PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: createdAt,
ExpiresAt: expiresAt,
},
},
}
str := rk.String()
assert.Contains(t, str, "RootKey")
assert.Contains(t, str, computeKeyID(pub).String())
assert.Contains(t, str, "2024-01-15")
assert.Contains(t, str, "2034-01-15")
}
func TestRootKey_String_NoExpiration(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
rk := RootKey{
PrivateKey{
Key: priv,
Metadata: KeyMetadata{
ID: computeKeyID(pub),
CreatedAt: createdAt,
ExpiresAt: time.Time{}, // No expiration
},
},
}
str := rk.String()
assert.Contains(t, str, "RootKey")
assert.Contains(t, str, "0001-01-01") // Zero time format
}
// Test GenerateRootKey
func TestGenerateRootKey_Valid(t *testing.T) {
expiration := 10 * 365 * 24 * time.Hour // 10 years
rk, privPEM, pubPEM, err := GenerateRootKey(expiration)
require.NoError(t, err)
assert.NotNil(t, rk)
assert.NotEmpty(t, privPEM)
assert.NotEmpty(t, pubPEM)
// Verify the key has correct metadata
assert.False(t, rk.Metadata.CreatedAt.IsZero())
assert.False(t, rk.Metadata.ExpiresAt.IsZero())
assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt))
// Verify expiration is approximately correct
expectedExpiration := time.Now().Add(expiration)
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
}
func TestGenerateRootKey_ShortExpiration(t *testing.T) {
expiration := 24 * time.Hour // 1 day
rk, _, _, err := GenerateRootKey(expiration)
require.NoError(t, err)
assert.NotNil(t, rk)
// Verify expiration
expectedExpiration := time.Now().Add(expiration)
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
}
func TestGenerateRootKey_ZeroExpiration(t *testing.T) {
rk, _, _, err := GenerateRootKey(0)
require.NoError(t, err)
assert.NotNil(t, rk)
// With zero expiration, ExpiresAt should be equal to CreatedAt
assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt)
}
func TestGenerateRootKey_PEMFormat(t *testing.T) {
rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Verify private key PEM
privBlock, _ := pem.Decode(privPEM)
require.NotNil(t, privBlock)
assert.Equal(t, tagRootPrivate, privBlock.Type)
var privKey PrivateKey
err = json.Unmarshal(privBlock.Bytes, &privKey)
require.NoError(t, err)
assert.Equal(t, rk.Key, privKey.Key)
// Verify public key PEM
pubBlock, _ := pem.Decode(pubPEM)
require.NotNil(t, pubBlock)
assert.Equal(t, tagRootPublic, pubBlock.Type)
var pubKey PublicKey
err = json.Unmarshal(pubBlock.Bytes, &pubKey)
require.NoError(t, err)
assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID)
}
func TestGenerateRootKey_KeySize(t *testing.T) {
rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Ed25519 private key should be 64 bytes
assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key))
// Ed25519 public key should be 32 bytes
pubKey := rk.Key.Public().(ed25519.PublicKey)
assert.Equal(t, ed25519.PublicKeySize, len(pubKey))
}
func TestGenerateRootKey_UniqueKeys(t *testing.T) {
rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Different keys should have different IDs
assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID)
assert.NotEqual(t, rk1.Key, rk2.Key)
}
// Test ParseRootKey
func TestParseRootKey_Valid(t *testing.T) {
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
parsed, err := ParseRootKey(privPEM)
require.NoError(t, err)
assert.NotNil(t, parsed)
// Verify the parsed key matches the original
assert.Equal(t, original.Key, parsed.Key)
assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix())
assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix())
}
func TestParseRootKey_InvalidPEM(t *testing.T) {
_, err := ParseRootKey([]byte("not a valid PEM"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse")
}
func TestParseRootKey_EmptyData(t *testing.T) {
_, err := ParseRootKey([]byte{})
assert.Error(t, err)
}
func TestParseRootKey_WrongType(t *testing.T) {
// Generate an artifact key instead of root key
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
// Try to parse artifact key as root key
_, err = ParseRootKey(privPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "PEM type")
// Just to use artifactKey to avoid unused variable warning
_ = artifactKey
}
func TestParseRootKey_CorruptedJSON(t *testing.T) {
// Create PEM with corrupted JSON
corruptedPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: []byte("corrupted json data"),
})
_, err := ParseRootKey(corruptedPEM)
assert.Error(t, err)
}
func TestParseRootKey_InvalidKeySize(t *testing.T) {
// Create a key with invalid size
invalidKey := PrivateKey{
Key: []byte{0x01, 0x02, 0x03}, // Too short
Metadata: KeyMetadata{
ID: KeyID{},
CreatedAt: time.Now().UTC(),
},
}
privJSON, err := json.Marshal(invalidKey)
require.NoError(t, err)
invalidPEM := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON,
})
_, err = ParseRootKey(invalidPEM)
assert.Error(t, err)
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
}
func TestParseRootKey_Roundtrip(t *testing.T) {
// Generate a key
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Parse it
parsed, err := ParseRootKey(privPEM)
require.NoError(t, err)
// Generate PEM again from parsed key
privJSON2, err := json.Marshal(parsed.PrivateKey)
require.NoError(t, err)
privPEM2 := pem.EncodeToMemory(&pem.Block{
Type: tagRootPrivate,
Bytes: privJSON2,
})
// Parse again
parsed2, err := ParseRootKey(privPEM2)
require.NoError(t, err)
// Should still match original
assert.Equal(t, original.Key, parsed2.Key)
assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID)
}
// Test SignArtifactKey
func TestSignArtifactKey_Valid(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data := []byte("test data to sign")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Parse and verify signature
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, rootKey.Metadata.ID, sig.KeyID)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "sha512", sig.HashAlgo)
assert.False(t, sig.Timestamp.IsZero())
}
func TestSignArtifactKey_EmptyData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
sigData, err := SignArtifactKey(*rootKey, []byte{})
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Should still be able to parse
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestSignArtifactKey_Verify(t *testing.T) {
rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Parse public key
pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
require.NoError(t, err)
// Sign some data
data := []byte("test data for verification")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
// Parse signature
sig, err := ParseSignature(sigData)
require.NoError(t, err)
// Reconstruct message
msg := make([]byte, 0, len(data)+8)
msg = append(msg, data...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
// Verify signature
valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
assert.True(t, valid)
}
func TestSignArtifactKey_DifferentData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data1 := []byte("data1")
data2 := []byte("data2")
sig1, err := SignArtifactKey(*rootKey, data1)
require.NoError(t, err)
sig2, err := SignArtifactKey(*rootKey, data2)
require.NoError(t, err)
// Different data should produce different signatures
assert.NotEqual(t, sig1, sig2)
}
func TestSignArtifactKey_MultipleSignatures(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
data := []byte("test data")
// Sign twice with a small delay
sig1, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
time.Sleep(10 * time.Millisecond)
sig2, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
// Signatures should be different due to different timestamps
assert.NotEqual(t, sig1, sig2)
// Parse both signatures
parsed1, err := ParseSignature(sig1)
require.NoError(t, err)
parsed2, err := ParseSignature(sig2)
require.NoError(t, err)
// Timestamps should be different
assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp))
}
func TestSignArtifactKey_LargeData(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
// Create 1MB of data
largeData := make([]byte, 1024*1024)
for i := range largeData {
largeData[i] = byte(i % 256)
}
sigData, err := SignArtifactKey(*rootKey, largeData)
require.NoError(t, err)
assert.NotEmpty(t, sigData)
// Verify signature can be parsed
sig, err := ParseSignature(sigData)
require.NoError(t, err)
assert.NotEmpty(t, sig.Signature)
}
func TestSignArtifactKey_TimestampInSignature(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
beforeSign := time.Now().UTC()
data := []byte("test data")
sigData, err := SignArtifactKey(*rootKey, data)
require.NoError(t, err)
afterSign := time.Now().UTC()
sig, err := ParseSignature(sigData)
require.NoError(t, err)
// Timestamp should be between before and after
assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second)))
assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second)))
}
// Integration test
func TestRootKey_FullWorkflow(t *testing.T) {
// Step 1: Generate root key
rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
assert.NotNil(t, rootKey)
assert.NotEmpty(t, privPEM)
assert.NotEmpty(t, pubPEM)
// Step 2: Parse the private key back
parsedRootKey, err := ParseRootKey(privPEM)
require.NoError(t, err)
assert.Equal(t, rootKey.Key, parsedRootKey.Key)
assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID)
// Step 3: Generate an artifact key using root key
artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
assert.NotNil(t, artifactKey)
// Step 4: Verify the artifact key signature
pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
require.NoError(t, err)
sig, err := ParseSignature(artifactSig)
require.NoError(t, err)
artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic)
require.NoError(t, err)
// Reconstruct message - SignArtifactKey signs the PEM, not the JSON
msg := make([]byte, 0, len(artifactPubPEM)+8)
msg = append(msg, artifactPubPEM...)
msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))
// Verify with root public key
valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
assert.True(t, valid, "Artifact key signature should be valid")
// Step 5: Use artifact key to sign data
testData := []byte("This is test artifact data")
dataSig, err := SignData(*artifactKey, testData)
require.NoError(t, err)
assert.NotEmpty(t, dataSig)
// Step 6: Verify the artifact data signature
dataSigParsed, err := ParseSignature(dataSig)
require.NoError(t, err)
err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed)
assert.NoError(t, err, "Artifact data signature should be valid")
}
func TestRootKey_ExpiredKeyWorkflow(t *testing.T) {
// Generate a root key that expires very soon
rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond)
require.NoError(t, err)
// Wait for expiration
time.Sleep(10 * time.Millisecond)
// Try to generate artifact key with expired root key
_, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
assert.Error(t, err)
assert.Contains(t, err.Error(), "expired")
}

View File

@@ -0,0 +1,24 @@
package reposign
import (
"encoding/json"
"time"
)
// Signature contains a signature with associated Metadata
type Signature struct {
Signature []byte `json:"signature"`
Timestamp time.Time `json:"timestamp"`
KeyID KeyID `json:"key_id"`
Algorithm string `json:"algorithm"` // "ed25519"
HashAlgo string `json:"hash_algo"` // "blake2s" or sha512
}
func ParseSignature(data []byte) (*Signature, error) {
var signature Signature
if err := json.Unmarshal(data, &signature); err != nil {
return nil, err
}
return &signature, nil
}

View File

@@ -0,0 +1,277 @@
package reposign
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseSignature_Valid(t *testing.T) {
timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
keyID, err := ParseKeyID("0123456789abcdef")
require.NoError(t, err)
signatureData := []byte{0x01, 0x02, 0x03, 0x04}
jsonData, err := json.Marshal(Signature{
Signature: signatureData,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
})
require.NoError(t, err)
sig, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.Equal(t, signatureData, sig.Signature)
assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix())
assert.Equal(t, keyID, sig.KeyID)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "blake2s", sig.HashAlgo)
}
func TestParseSignature_InvalidJSON(t *testing.T) {
invalidJSON := []byte(`{invalid json}`)
sig, err := ParseSignature(invalidJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestParseSignature_EmptyData(t *testing.T) {
emptyJSON := []byte(`{}`)
sig, err := ParseSignature(emptyJSON)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.Empty(t, sig.Signature)
assert.True(t, sig.Timestamp.IsZero())
assert.Equal(t, KeyID{}, sig.KeyID)
assert.Empty(t, sig.Algorithm)
assert.Empty(t, sig.HashAlgo)
}
func TestParseSignature_MissingFields(t *testing.T) {
// JSON with only some fields
partialJSON := []byte(`{
"signature": "AQIDBA==",
"algorithm": "ed25519"
}`)
sig, err := ParseSignature(partialJSON)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.True(t, sig.Timestamp.IsZero())
}
func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) {
timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC)
keyID, err := ParseKeyID("fedcba9876543210")
require.NoError(t, err)
original := Signature{
Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "sha512",
}
// Marshal
jsonData, err := json.Marshal(original)
require.NoError(t, err)
// Unmarshal
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
// Verify
assert.Equal(t, original.Signature, parsed.Signature)
assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix())
assert.Equal(t, original.KeyID, parsed.KeyID)
assert.Equal(t, original.Algorithm, parsed.Algorithm)
assert.Equal(t, original.HashAlgo, parsed.HashAlgo)
}
func TestSignature_NilSignatureBytes(t *testing.T) {
timestamp := time.Now().UTC()
keyID, err := ParseKeyID("0011223344556677")
require.NoError(t, err)
sig := Signature{
Signature: nil,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Nil(t, parsed.Signature)
}
func TestSignature_LargeSignature(t *testing.T) {
timestamp := time.Now().UTC()
keyID, err := ParseKeyID("aabbccddeeff0011")
require.NoError(t, err)
// Create a large signature (64 bytes for ed25519)
largeSignature := make([]byte, 64)
for i := range largeSignature {
largeSignature[i] = byte(i)
}
sig := Signature{
Signature: largeSignature,
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, largeSignature, parsed.Signature)
}
func TestSignature_WithDifferentHashAlgorithms(t *testing.T) {
tests := []struct {
name string
hashAlgo string
}{
{"blake2s", "blake2s"},
{"sha512", "sha512"},
{"sha256", "sha256"},
{"empty", ""},
}
keyID, err := ParseKeyID("1122334455667788")
require.NoError(t, err)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sig := Signature{
Signature: []byte{0x01, 0x02},
Timestamp: time.Now().UTC(),
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: tt.hashAlgo,
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, tt.hashAlgo, parsed.HashAlgo)
})
}
}
func TestSignature_TimestampPrecision(t *testing.T) {
// Test that timestamp preserves precision through JSON marshaling
timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC)
keyID, err := ParseKeyID("8877665544332211")
require.NoError(t, err)
sig := Signature{
Signature: []byte{0xaa, 0xbb},
Timestamp: timestamp,
KeyID: keyID,
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
// JSON timestamps typically have second or millisecond precision
// so we check that at least seconds match
assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix())
}
func TestParseSignature_MalformedKeyID(t *testing.T) {
// Test with a malformed KeyID field
malformedJSON := []byte(`{
"signature": "AQID",
"timestamp": "2024-01-15T10:30:00Z",
"key_id": "invalid_keyid_format",
"algorithm": "ed25519",
"hash_algo": "blake2s"
}`)
// This should fail since "invalid_keyid_format" is not a valid KeyID
sig, err := ParseSignature(malformedJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestParseSignature_InvalidTimestamp(t *testing.T) {
// Test with an invalid timestamp format
invalidTimestampJSON := []byte(`{
"signature": "AQID",
"timestamp": "not-a-timestamp",
"key_id": "0123456789abcdef",
"algorithm": "ed25519",
"hash_algo": "blake2s"
}`)
sig, err := ParseSignature(invalidTimestampJSON)
assert.Error(t, err)
assert.Nil(t, sig)
}
func TestSignature_ZeroKeyID(t *testing.T) {
// Test with a zero KeyID
sig := Signature{
Signature: []byte{0x01, 0x02, 0x03},
Timestamp: time.Now().UTC(),
KeyID: KeyID{},
Algorithm: "ed25519",
HashAlgo: "blake2s",
}
jsonData, err := json.Marshal(sig)
require.NoError(t, err)
parsed, err := ParseSignature(jsonData)
require.NoError(t, err)
assert.Equal(t, KeyID{}, parsed.KeyID)
}
func TestParseSignature_ExtraFields(t *testing.T) {
// JSON with extra fields that should be ignored
jsonWithExtra := []byte(`{
"signature": "AQIDBA==",
"timestamp": "2024-01-15T10:30:00Z",
"key_id": "0123456789abcdef",
"algorithm": "ed25519",
"hash_algo": "blake2s",
"extra_field": "should be ignored",
"another_extra": 12345
}`)
sig, err := ParseSignature(jsonWithExtra)
require.NoError(t, err)
assert.NotNil(t, sig)
assert.NotEmpty(t, sig.Signature)
assert.Equal(t, "ed25519", sig.Algorithm)
assert.Equal(t, "blake2s", sig.HashAlgo)
}

View File

@@ -0,0 +1,187 @@
package reposign
import (
"context"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
)
const (
artifactPubKeysFileName = "artifact-key-pub.pem"
artifactPubKeysSigFileName = "artifact-key-pub.pem.sig"
revocationFileName = "revocation-list.json"
revocationSignFileName = "revocation-list.json.sig"
keySizeLimit = 5 * 1024 * 1024 //5MB
signatureLimit = 1024
revocationLimit = 10 * 1024 * 1024
)
type ArtifactVerify struct {
rootKeys []PublicKey
keysBaseURL *url.URL
revocationList *RevocationList
}
func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) {
allKeys, err := loadEmbeddedPublicKeys()
if err != nil {
return nil, err
}
return newArtifactVerify(keysBaseURL, allKeys)
}
func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) {
ku, err := url.Parse(keysBaseURL)
if err != nil {
return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err)
}
a := &ArtifactVerify{
rootKeys: allKeys,
keysBaseURL: ku,
}
return a, nil
}
func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error {
version = strings.TrimPrefix(version, "v")
revocationList, err := a.loadRevocationList(ctx)
if err != nil {
return fmt.Errorf("failed to load revocation list: %v", err)
}
a.revocationList = revocationList
artifactPubKeys, err := a.loadArtifactKeys(ctx)
if err != nil {
return fmt.Errorf("failed to load artifact keys: %v", err)
}
signature, err := a.loadArtifactSignature(ctx, version, artifactFile)
if err != nil {
return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err)
}
artifactData, err := os.ReadFile(artifactFile)
if err != nil {
log.Errorf("failed to read artifact file: %v", err)
return fmt.Errorf("failed to read artifact file: %w", err)
}
if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil {
return fmt.Errorf("failed to validate artifact: %v", err)
}
return nil
}
func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) {
downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String()
data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit)
if err != nil {
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
return nil, err
}
downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String()
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
return nil, err
}
signature, err := ParseSignature(sigData)
if err != nil {
log.Debugf("failed to parse revocation list signature: %s", err)
return nil, err
}
return ValidateRevocationList(a.rootKeys, data, *signature)
}
func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) {
downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String()
log.Debugf("starting downloading artifact keys from: %s", downloadURL)
data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit)
if err != nil {
log.Debugf("failed to download artifact keys: %s", err)
return nil, err
}
downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String()
log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL)
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download signature of public keys: %s", err)
return nil, err
}
signature, err := ParseSignature(sigData)
if err != nil {
log.Debugf("failed to parse signature of public keys: %s", err)
return nil, err
}
return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList)
}
func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) {
artifactFile = filepath.Base(artifactFile)
downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String()
data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
if err != nil {
log.Debugf("failed to download artifact signature: %s", err)
return nil, err
}
signature, err := ParseSignature(data)
if err != nil {
log.Debugf("failed to parse artifact signature: %s", err)
return nil, err
}
return signature, nil
}
func loadEmbeddedPublicKeys() ([]PublicKey, error) {
files, err := embeddedCerts.ReadDir(embeddedCertsDir)
if err != nil {
return nil, fmt.Errorf("failed to read embedded certs: %w", err)
}
var allKeys []PublicKey
for _, file := range files {
if file.IsDir() {
continue
}
data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name())
if err != nil {
return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err)
}
keys, err := parsePublicKeyBundle(data, tagRootPublic)
if err != nil {
return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err)
}
allKeys = append(allKeys, keys...)
}
if len(allKeys) == 0 {
return nil, fmt.Errorf("no valid public keys found in embedded certs")
}
return allKeys, nil
}

View File

@@ -0,0 +1,528 @@
package reposign
import (
"context"
"crypto/ed25519"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test ArtifactVerify construction
func TestArtifactVerify_Construction(t *testing.T) {
// Generate test root key
rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic)
require.NoError(t, err)
keysBaseURL := "http://localhost:8080/artifact-signatures"
av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey})
require.NoError(t, err)
assert.NotNil(t, av)
assert.NotEmpty(t, av.rootKeys)
assert.Equal(t, keysBaseURL, av.keysBaseURL.String())
// Verify root key structure
assert.NotEmpty(t, av.rootKeys[0].Key)
assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID)
assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero())
}
func TestArtifactVerify_MultipleRootKeys(t *testing.T) {
// Generate multiple test root keys
rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic)
require.NoError(t, err)
rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic)
require.NoError(t, err)
keysBaseURL := "http://localhost:8080/artifact-signatures"
av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2})
assert.NoError(t, err)
assert.Len(t, av.rootKeys, 2)
assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID)
}
// Test Verify workflow with mock HTTP server
func TestArtifactVerify_FullWorkflow(t *testing.T) {
// Create temporary test directory
tempDir := t.TempDir()
// Step 1: Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Step 2: Generate artifact key
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
// Step 3: Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Step 4: Bundle artifact keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Step 5: Create test artifact
artifactPath := filepath.Join(tempDir, "test-artifact.bin")
artifactData := []byte("This is test artifact data for verification")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
// Step 6: Sign artifact
artifactSigData, err := SignData(*artifactKey, artifactData)
require.NoError(t, err)
// Step 7: Setup mock HTTP server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifacts/v1.0.0/test-artifact.bin":
_, _ = w.Write(artifactData)
case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
// Step 8: Create ArtifactVerify with test root key
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
// Step 9: Verify artifact
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.NoError(t, err)
}
func TestArtifactVerify_InvalidRevocationList(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
// Setup server with invalid revocation list
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write([]byte("invalid data"))
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to load revocation list")
}
func TestArtifactVerify_MissingArtifactFile(t *testing.T) {
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
// Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Create signature for non-existent file
testData := []byte("test")
artifactSigData, err := SignData(*artifactKey, testData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/missing.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", "file.bin")
assert.Error(t, err)
}
func TestArtifactVerify_ServerUnavailable(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
// Use URL that doesn't exist
av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
}
func TestArtifactVerify_ContextCancellation(t *testing.T) {
tempDir := t.TempDir()
artifactPath := filepath.Join(tempDir, "test.bin")
err := os.WriteFile(artifactPath, []byte("test"), 0644)
require.NoError(t, err)
// Create a server that delays response
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(500 * time.Millisecond)
_, _ = w.Write([]byte("data"))
}))
defer server.Close()
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
require.NoError(t, err)
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey})
require.NoError(t, err)
// Create context that cancels quickly
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err = av.Verify(ctx, "1.0.0", artifactPath)
assert.Error(t, err)
}
func TestArtifactVerify_WithRevocation(t *testing.T) {
tempDir := t.TempDir()
// Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Generate two artifact keys
artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
require.NoError(t, err)
_, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
require.NoError(t, err)
// Create revocation list with first key revoked
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
parsedRevocation, err := ParseRevocationList(emptyRevocation)
require.NoError(t, err)
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle both keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
require.NoError(t, err)
// Create artifact signed by revoked key
artifactPath := filepath.Join(tempDir, "test.bin")
artifactData := []byte("test data")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
artifactSigData, err := SignData(*artifactKey1, artifactData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should fail because the signing key is revoked
assert.Error(t, err)
assert.Contains(t, err.Error(), "no signing Key found")
}
func TestArtifactVerify_ValidWithSecondKey(t *testing.T) {
tempDir := t.TempDir()
// Generate root key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
// Generate two artifact keys
_, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
require.NoError(t, err)
artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
require.NoError(t, err)
// Create revocation list with first key revoked
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
parsedRevocation, err := ParseRevocationList(emptyRevocation)
require.NoError(t, err)
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle both keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
require.NoError(t, err)
// Create artifact signed by second key (not revoked)
artifactPath := filepath.Join(tempDir, "test.bin")
artifactData := []byte("test data")
err = os.WriteFile(artifactPath, artifactData, 0644)
require.NoError(t, err)
artifactSigData, err := SignData(*artifactKey2, artifactData)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should succeed because second key is not revoked
assert.NoError(t, err)
}
func TestArtifactVerify_TamperedArtifact(t *testing.T) {
tempDir := t.TempDir()
// Generate root key and artifact key
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
require.NoError(t, err)
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
require.NoError(t, err)
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
require.NoError(t, err)
// Create revocation list
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
require.NoError(t, err)
// Bundle keys
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
require.NoError(t, err)
// Sign original data
originalData := []byte("original data")
artifactSigData, err := SignData(*artifactKey, originalData)
require.NoError(t, err)
// Write tampered data to file
artifactPath := filepath.Join(tempDir, "test.bin")
tamperedData := []byte("tampered data")
err = os.WriteFile(artifactPath, tamperedData, 0644)
require.NoError(t, err)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/artifact-signatures/keys/" + revocationFileName:
_, _ = w.Write(revocationData)
case "/artifact-signatures/keys/" + revocationSignFileName:
_, _ = w.Write(revocationSig)
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
_, _ = w.Write(artifactKeysBundle)
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
_, _ = w.Write(artifactKeysSig)
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
_, _ = w.Write(artifactSigData)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
rootPubKey := PublicKey{
Key: rootKey.Key.Public().(ed25519.PublicKey),
Metadata: rootKey.Metadata,
}
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
require.NoError(t, err)
ctx := context.Background()
err = av.Verify(ctx, "1.0.0", artifactPath)
// Should fail because artifact was tampered
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to validate artifact")
}
// Test URL validation
func TestArtifactVerify_URLParsing(t *testing.T) {
tests := []struct {
name string
keysBaseURL string
expectError bool
}{
{
name: "Valid HTTP URL",
keysBaseURL: "http://example.com/artifact-signatures",
expectError: false,
},
{
name: "Valid HTTPS URL",
keysBaseURL: "https://example.com/artifact-signatures",
expectError: false,
},
{
name: "URL with port",
keysBaseURL: "http://localhost:8080/artifact-signatures",
expectError: false,
},
{
name: "Invalid URL",
keysBaseURL: "://invalid",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := newArtifactVerify(tt.keysBaseURL, nil)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,11 @@
package updatemanager
import v "github.com/hashicorp/go-version"
type UpdateInterface interface {
StopWatch()
SetDaemonVersion(newVersion string) bool
SetOnUpdateListener(updateFn func())
LatestVersion() *v.Version
StartFetcher()
}

View File

@@ -131,7 +131,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}

View File

@@ -51,7 +51,7 @@
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v6.32.1
// protoc v3.21.12
// source: daemon.proto
package proto
@@ -893,6 +893,7 @@ type UpRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -941,6 +942,13 @@ func (x *UpRequest) GetUsername() string {
return ""
}
func (x *UpRequest) GetAutoUpdate() bool {
if x != nil && x.AutoUpdate != nil {
return *x.AutoUpdate
}
return false
}
type UpResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -5356,6 +5364,94 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
return 0
}
type InstallerResultRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InstallerResultRequest) Reset() {
*x = InstallerResultRequest{}
mi := &file_daemon_proto_msgTypes[79]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InstallerResultRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InstallerResultRequest) ProtoMessage() {}
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[79]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{79}
}
type InstallerResultResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *InstallerResultResponse) Reset() {
*x = InstallerResultResponse{}
mi := &file_daemon_proto_msgTypes[80]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *InstallerResultResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*InstallerResultResponse) ProtoMessage() {}
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{80}
}
func (x *InstallerResultResponse) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
func (x *InstallerResultResponse) GetErrorMsg() string {
if x != nil {
return x.ErrorMsg
}
return ""
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -5366,7 +5462,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[82]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -5378,7 +5474,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[80]
mi := &file_daemon_proto_msgTypes[82]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -5502,12 +5598,16 @@ const file_daemon_proto_rawDesc = "" +
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
"\x14WaitSSOLoginResponse\x12\x14\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"p\n" +
"\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" +
"\tUpRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" +
"\n" +
"autoUpdate\x18\x03 \x01(\bH\x02R\n" +
"autoUpdate\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\f\n" +
"\t_usernameB\r\n" +
"\v_autoUpdate\"\f\n" +
"\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
@@ -5893,7 +5993,11 @@ const file_daemon_proto_rawDesc = "" +
"\x14WaitJWTTokenResponse\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
"\x16InstallerResultRequest\"O\n" +
"\x17InstallerResultResponse\x12\x18\n" +
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -5902,7 +6006,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xdb\x12\n" +
"\x05TRACE\x10\a2\xb4\x13\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -5938,7 +6042,8 @@ const file_daemon_proto_rawDesc = "" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -5953,7 +6058,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
@@ -6038,19 +6143,21 @@ var file_daemon_proto_goTypes = []any{
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
nil, // 83: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range
nil, // 85: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 86: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
nil, // 85: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
nil, // 87: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
@@ -6061,8 +6168,8 @@ var file_daemon_proto_depIdxs = []int32{
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -6073,10 +6180,10 @@ var file_daemon_proto_depIdxs = []int32{
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -6111,40 +6218,42 @@ var file_daemon_proto_depIdxs = []int32{
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
66, // [66:98] is the sub-list for method output_type
34, // [34:66] is the sub-list for method input_type
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
67, // [67:100] is the sub-list for method output_type
34, // [34:67] is the sub-list for method input_type
34, // [34:34] is the sub-list for extension type_name
34, // [34:34] is the sub-list for extension extendee
0, // [0:34] is the sub-list for field type_name
@@ -6174,7 +6283,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 4,
NumMessages: 82,
NumMessages: 84,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -95,6 +95,8 @@ service DaemonService {
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
}
@@ -215,6 +217,7 @@ message WaitSSOLoginResponse {
message UpRequest {
optional string profileName = 1;
optional string username = 2;
optional bool autoUpdate = 3;
}
message UpResponse {}
@@ -772,3 +775,11 @@ message WaitJWTTokenResponse {
// expiration time in seconds
int64 expiresIn = 3;
}
message InstallerResultRequest {
}
message InstallerResultResponse {
bool success = 1;
string errorMsg = 2;
}

View File

@@ -71,6 +71,7 @@ type DaemonServiceClient interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
}
type daemonServiceClient struct {
@@ -392,6 +393,15 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec
return out, nil
}
func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) {
out := new(InstallerResultResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -449,6 +459,7 @@ type DaemonServiceServer interface {
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -552,6 +563,9 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
}
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1144,6 +1158,24 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(InstallerResultRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetInstallerResult(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetInstallerResult",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1275,6 +1307,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "NotifyOSLifecycle",
Handler: _DaemonService_NotifyOSLifecycle_Handler,
},
{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -14,4 +14,4 @@ cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"
cd "$old_pwd"

View File

@@ -192,7 +192,7 @@ func (s *Server) Start() error {
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -223,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
@@ -231,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
@@ -260,7 +260,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
}()
runOperation := func() error {
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
doInitialAutoUpdate = false
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
@@ -728,7 +729,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
var doAutoUpdate bool
if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
doAutoUpdate = true
}
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
@@ -1539,9 +1545,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err

View File

@@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}

View File

@@ -0,0 +1,30 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/proto"
)
func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) {
inst := installer.New()
dir := inst.TempDir()
rh := installer.NewResultHandler(dir)
result, err := rh.Watch(ctx)
if err != nil {
log.Errorf("failed to watch update result: %v", err)
return &proto.InstallerResultResponse{
Success: false,
ErrorMsg: err.Error(),
}, nil
}
return &proto.InstallerResultResponse{
Success: result.Success,
ErrorMsg: result.Error,
}, nil
}

View File

@@ -34,6 +34,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
protobuf "google.golang.org/protobuf/proto"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
@@ -43,7 +44,6 @@ import (
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
"github.com/netbirdio/netbird/client/ui/process"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -87,22 +87,24 @@ func main() {
// Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(&newServiceClientArgs{
addr: flags.daemonAddr,
logFile: logFile,
app: a,
showSettings: flags.showSettings,
showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
showQuickActions: flags.showQuickActions,
addr: flags.daemonAddr,
logFile: logFile,
app: a,
showSettings: flags.showSettings,
showNetworks: flags.showNetworks,
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
showQuickActions: flags.showQuickActions,
showUpdate: flags.showUpdate,
showUpdateVersion: flags.showUpdateVersion,
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
a.Run()
return
}
@@ -128,15 +130,17 @@ func main() {
}
type cliFlags struct {
daemonAddr string
showSettings bool
showNetworks bool
showProfiles bool
showDebug bool
showLoginURL bool
showQuickActions bool
errorMsg string
saveLogsInFile bool
daemonAddr string
showSettings bool
showNetworks bool
showProfiles bool
showDebug bool
showLoginURL bool
showQuickActions bool
errorMsg string
saveLogsInFile bool
showUpdate bool
showUpdateVersion string
}
// parseFlags reads and returns all needed command-line flags.
@@ -156,6 +160,8 @@ func parseFlags() *cliFlags {
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.Parse()
return &flags
}
@@ -319,6 +325,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
wUpdateProgress fyne.Window
updateContextCancel context.CancelFunc
connectCancel context.CancelFunc
}
@@ -329,15 +337,17 @@ type menuHandler struct {
}
type newServiceClientArgs struct {
addr string
logFile string
app fyne.App
showSettings bool
showNetworks bool
showDebug bool
showLoginURL bool
showProfiles bool
showQuickActions bool
addr string
logFile string
app fyne.App
showSettings bool
showNetworks bool
showDebug bool
showLoginURL bool
showProfiles bool
showQuickActions bool
showUpdate bool
showUpdateVersion string
}
// newServiceClient instance constructor
@@ -355,7 +365,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
update: version.NewUpdate("nb/client-ui"),
update: version.NewUpdateAndStart("nb/client-ui"),
}
s.eventHandler = newEventHandler(s)
@@ -375,6 +385,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
}
return s
@@ -814,7 +826,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log
return nil
}
func (s *serviceClient) menuUpClick(ctx context.Context) error {
func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -836,7 +848,9 @@ func (s *serviceClient) menuUpClick(ctx context.Context) error {
return nil
}
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{
AutoUpdate: protobuf.Bool(wannaAutoUpdate),
}); err != nil {
return fmt.Errorf("start connection: %w", err)
}
@@ -1097,6 +1111,26 @@ func (s *serviceClient) onTrayReady() {
s.updateExitNodes()
}
})
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
// todo use new Category
if windowAction, ok := event.Metadata["progress_window"]; ok {
targetVersion, ok := event.Metadata["version"]
if !ok {
targetVersion = "unknown"
}
log.Debugf("window action: %v", windowAction)
if windowAction == "show" {
if s.updateContextCancel != nil {
s.updateContextCancel()
s.updateContextCancel = nil
}
subCtx, cancel := context.WithCancel(s.ctx)
go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion)
s.updateContextCancel = cancel
}
}
})
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)

View File

@@ -80,7 +80,7 @@ func (h *eventHandler) handleConnectClick() {
go func() {
defer connectCancel()
if err := h.client.menuUpClick(connectCtx); err != nil {
if err := h.client.menuUpClick(connectCtx, true); err != nil {
st, ok := status.FromError(err)
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
@@ -185,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() {
defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig()
h.runSelfCommand(h.client.ctx, "settings", "true")
h.runSelfCommand(h.client.ctx, "settings")
}()
}
@@ -193,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable()
go func() {
defer h.client.mCreateDebugBundle.Enable()
h.runSelfCommand(h.client.ctx, "debug", "true")
h.runSelfCommand(h.client.ctx, "debug")
}()
}
@@ -217,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable()
go func() {
defer h.client.mNetworks.Enable()
h.runSelfCommand(h.client.ctx, "networks", "true")
h.runSelfCommand(h.client.ctx, "networks")
}()
}
@@ -237,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error {
return nil
}
func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("error getting executable path: %v", err)
return
}
cmd := exec.CommandContext(ctx, proc,
fmt.Sprintf("--%s=%s", command, arg),
// Build the full command arguments
cmdArgs := []string{
fmt.Sprintf("--%s=true", command),
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
)
}
cmdArgs = append(cmdArgs, args...)
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
if out := h.client.attachOutput(cmd); out != nil {
defer func() {
@@ -257,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string)
}()
}
log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr)
log.Printf("running command: %s", cmd.String())
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode())
}
return
}
log.Printf("command '%s %s' completed successfully", command, arg)
log.Printf("command '%s' completed successfully", cmd.String())
}
func (h *eventHandler) logout(ctx context.Context) error {

View File

@@ -397,7 +397,7 @@ type profileMenu struct {
logoutSubItem *subItem
profilesState []Profile
downClickCallback func() error
upClickCallback func(context.Context) error
upClickCallback func(context.Context, bool) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -411,7 +411,7 @@ type newProfileMenuArgs struct {
profileMenuItem *systray.MenuItem
emailMenuItem *systray.MenuItem
downClickCallback func() error
upClickCallback func(context.Context) error
upClickCallback func(context.Context, bool) error
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
loadSettingsCallback func()
app fyne.App
@@ -579,7 +579,7 @@ func (p *profileMenu) refresh() {
connectCtx, connectCancel := context.WithCancel(p.ctx)
p.serviceClient.connectCancel = connectCancel
if err := p.upClickCallback(connectCtx); err != nil {
if err := p.upClickCallback(connectCtx, false); err != nil {
log.Errorf("failed to handle up click after switching profile: %v", err)
}

View File

@@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() {
connCmd := connectCommand{
connectClient: func() error {
return s.menuUpClick(s.ctx)
return s.menuUpClick(s.ctx, false)
},
}

140
client/ui/update.go Normal file
View File

@@ -0,0 +1,140 @@
//go:build !(linux && 386)
package main
import (
"context"
"errors"
"fmt"
"strings"
"time"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) {
log.Infof("show installer progress window: %s", version)
s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
statusLabel := widget.NewLabel("Updating...")
infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version))
content := container.NewVBox(infoLabel, statusLabel)
s.wUpdateProgress.SetContent(content)
s.wUpdateProgress.CenterOnScreen()
s.wUpdateProgress.SetFixedSize(true)
s.wUpdateProgress.SetCloseIntercept(func() {
// this is empty to lock window until result known
})
s.wUpdateProgress.RequestFocus()
s.wUpdateProgress.Show()
updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
// Initialize dot updater
updateText := dotUpdater()
// Channel to receive the result from RPC call
resultErrCh := make(chan error, 1)
resultOkCh := make(chan struct{}, 1)
// Start RPC call in background
go func() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Infof("backend not reachable, upgrade in progress: %v", err)
close(resultOkCh)
return
}
resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{})
if err != nil {
log.Infof("backend stopped responding, upgrade in progress: %v", err)
close(resultOkCh)
return
}
if !resp.Success {
resultErrCh <- mapInstallError(resp.ErrorMsg)
return
}
// Success
close(resultOkCh)
}()
// Update UI with dots and wait for result
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
defer cancel()
// allow closing update window after 10 sec
timerResetCloseInterceptor := time.NewTimer(10 * time.Second)
defer timerResetCloseInterceptor.Stop()
for {
select {
case <-updateWindowCtx.Done():
s.showInstallerResult(statusLabel, updateWindowCtx.Err())
return
case err := <-resultErrCh:
s.showInstallerResult(statusLabel, err)
return
case <-resultOkCh:
log.Info("backend exited, upgrade in progress, closing all UI")
killParentUIProcess()
s.app.Quit()
return
case <-ticker.C:
statusLabel.SetText(updateText())
case <-timerResetCloseInterceptor.C:
s.wUpdateProgress.SetCloseIntercept(nil)
}
}
}()
}
func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) {
s.wUpdateProgress.SetCloseIntercept(nil)
switch {
case errors.Is(err, context.DeadlineExceeded):
log.Warn("update watcher timed out")
statusLabel.SetText("Update timed out. Please try again.")
case errors.Is(err, context.Canceled):
log.Info("update watcher canceled")
statusLabel.SetText("Update canceled.")
case err != nil:
log.Errorf("update failed: %v", err)
statusLabel.SetText("Update failed: " + err.Error())
default:
s.wUpdateProgress.Close()
}
}
// dotUpdater returns a closure that cycles through dots for a loading animation.
func dotUpdater() func() string {
dotCount := 0
return func() string {
dotCount = (dotCount + 1) % 4
return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount))
}
}
func mapInstallError(msg string) error {
msg = strings.ToLower(strings.TrimSpace(msg))
switch {
case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"):
return context.DeadlineExceeded
case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"):
return context.Canceled
case msg == "":
return errors.New("unknown update error")
default:
return errors.New(msg)
}
}

View File

@@ -0,0 +1,7 @@
//go:build !windows && !(linux && 386)
package main
func killParentUIProcess() {
// No-op on non-Windows platforms
}

View File

@@ -0,0 +1,44 @@
//go:build windows
package main
import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
nbprocess "github.com/netbirdio/netbird/client/ui/process"
)
// killParentUIProcess finds and kills the parent systray UI process on Windows.
// This is a workaround in case the MSI installer fails to properly terminate the UI process.
// The installer should handle this via util:CloseApplication with TerminateProcess, but this
// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds.
func killParentUIProcess() {
pid, running, err := nbprocess.IsAnotherProcessRunning()
if err != nil {
log.Warnf("failed to check for parent UI process: %v", err)
return
}
if !running {
log.Debug("no parent UI process found to kill")
return
}
log.Infof("killing parent UI process (PID: %d)", pid)
// Open the process with terminate rights
handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid))
if err != nil {
log.Warnf("failed to open parent process %d: %v", pid, err)
return
}
defer func() {
_ = windows.CloseHandle(handle)
}()
// Terminate the process with exit code 0
if err := windows.TerminateProcess(handle, 0); err != nil {
log.Warnf("failed to terminate parent process %d: %v", pid, err)
}
}

View File

@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
s.update = version.NewUpdate("nb/management")
s.update = version.NewUpdateAndStart("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())

View File

@@ -7,6 +7,7 @@ import (
"strings"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -101,6 +102,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled,
AutoUpdate: &proto.AutoUpdateSettings{
Version: settings.AutoUpdateVersion,
},
}
}
@@ -108,9 +112,10 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
},
Checks: toProtocolChecks(ctx, checks),
}

View File

@@ -321,7 +321,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain {
oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
updateAccountPeers = true
}
@@ -360,6 +361,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
@@ -451,6 +453,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
}
}
func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{
"version": newSettings.AutoUpdateVersion,
})
}
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {

View File

@@ -181,6 +181,8 @@ const (
UserRejected Activity = 90
UserCreated Activity = 91
AccountAutoUpdateVersionUpdated Activity = 92
AccountDeleted Activity = 99999
)
@@ -287,9 +289,12 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
UserApproved: {"User approved", "user.approve"},
UserRejected: {"User rejected", "user.reject"},
UserCreated: {"User created", "user.create"},
UserApproved: {"User approved", "user.approve"},
UserRejected: {"User rejected", "user.reject"},
UserCreated: {"User created", "user.create"},
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
}
// StringCode returns a string code of the activity

View File

@@ -3,12 +3,15 @@ package accounts
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/gorilla/mux"
goversion "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
@@ -26,7 +29,9 @@ const (
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
MinNetworkBitsIPv4 = 28
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
MinNetworkBitsIPv6 = 120
MinNetworkBitsIPv6 = 120
disableAutoUpdate = "disabled"
autoUpdateLatestVersion = "latest"
)
// handler is a handler that handles the server.Account HTTP endpoints
@@ -162,6 +167,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
returnSettings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
}
if req.Settings.Extra != nil {
returnSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
}
}
if req.Settings.JwtGroupsEnabled != nil {
returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
}
if req.Settings.GroupsPropagationEnabled != nil {
returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
}
if req.Settings.JwtGroupsClaimName != nil {
returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
}
if req.Settings.DnsDomain != nil {
returnSettings.DNSDomain = *req.Settings.DnsDomain
}
if req.Settings.LazyConnectionEnabled != nil {
returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
}
if req.Settings.AutoUpdateVersion != nil {
_, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion)
if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion ||
*req.Settings.AutoUpdateVersion == disableAutoUpdate ||
err == nil {
returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion
} else if *req.Settings.AutoUpdateVersion != "" {
return nil, fmt.Errorf("invalid AutoUpdateVersion")
}
}
return returnSettings, nil
}
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
@@ -186,45 +246,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
settings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
}
if req.Settings.Extra != nil {
settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
}
}
if req.Settings.JwtGroupsEnabled != nil {
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
}
if req.Settings.GroupsPropagationEnabled != nil {
settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
}
if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
}
if req.Settings.DnsDomain != nil {
settings.DNSDomain = *req.Settings.DnsDomain
}
if req.Settings.LazyConnectionEnabled != nil {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
settings, err := h.updateAccountRequestSettings(req)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
@@ -313,6 +338,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
}
if settings.NetworkRange.IsValid() {

View File

@@ -121,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: true,
expectedID: accountID,
@@ -143,6 +144,30 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
},
{
name: "PutAccount OK with autoUpdateVersion",
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: true,
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr("latest"),
},
expectedArray: false,
expectedID: accountID,
@@ -165,6 +190,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -187,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -209,6 +236,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,

View File

@@ -52,6 +52,9 @@ type Settings struct {
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
LazyConnectionEnabled bool `gorm:"default:false"`
// AutoUpdateVersion client auto-update version
AutoUpdateVersion string `gorm:"default:'disabled'"`
}
// Copy copies the Settings struct
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
AutoUpdateVersion: s.AutoUpdateVersion,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -145,6 +145,10 @@ components:
description: Enables or disables experimental lazy connection
type: boolean
example: true
auto_update_version:
description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
type: string
example: "0.51.2"
required:
- peer_login_expiration_enabled
- peer_login_expiration

View File

@@ -291,6 +291,9 @@ type AccountRequest struct {
// AccountSettings defines model for AccountSettings.
type AccountSettings struct {
// AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
AutoUpdateVersion *string `json:"auto_update_version,omitempty"`
// DnsDomain Allows to define a custom dns domain for the account
DnsDomain *string `json:"dns_domain,omitempty"`
Extra *AccountExtraSettings `json:"extra,omitempty"`

File diff suppressed because it is too large Load Diff

View File

@@ -280,6 +280,18 @@ message PeerConfig {
bool LazyConnectionEnabled = 6;
int32 mtu = 7;
// Auto-update config
AutoUpdateSettings autoUpdate = 8;
}
message AutoUpdateSettings {
string version = 1;
/*
alwaysUpdate = true → Updates happen automatically in the background
alwaysUpdate = false → Updates only happen when triggered by a peer connection
*/
bool alwaysUpdate = 2;
}
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections

View File

@@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update {
currentVersion, _ = goversion.NewVersion("0.0.0")
}
latestAvailable, _ := goversion.NewVersion("0.0.0")
u := &Update{
httpAgent: httpAgent,
latestAvailable: latestAvailable,
uiVersion: currentVersion,
fetchTicker: time.NewTicker(fetchPeriod),
fetchDone: make(chan struct{}),
httpAgent: httpAgent,
uiVersion: currentVersion,
fetchDone: make(chan struct{}),
}
go u.startFetcher()
return u
}
func NewUpdateAndStart(httpAgent string) *Update {
u := NewUpdate(httpAgent)
go u.StartFetcher()
return u
}
// StopWatch stop the version info fetch loop
func (u *Update) StopWatch() {
if u.fetchTicker == nil {
return
}
u.fetchTicker.Stop()
select {
@@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) {
}
}
func (u *Update) startFetcher() {
func (u *Update) LatestVersion() *goversion.Version {
u.versionsLock.Lock()
defer u.versionsLock.Unlock()
return u.latestAvailable
}
func (u *Update) StartFetcher() {
if u.fetchTicker != nil {
return
}
u.fetchTicker = time.NewTicker(fetchPeriod)
if changed := u.fetchVersion(); changed {
u.checkUpdate()
}
@@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool {
u.versionsLock.Lock()
defer u.versionsLock.Unlock()
if u.latestAvailable == nil {
return false
}
if u.latestAvailable.GreaterThan(u.uiVersion) {
return true
}

View File

@@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
u := NewUpdate(httpAgent)
u := NewUpdateAndStart(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
@@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
u := NewUpdate(httpAgent)
u := NewUpdateAndStart(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true
@@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
wg.Add(1)
onUpdate := false
u := NewUpdate(httpAgent)
u := NewUpdateAndStart(httpAgent)
defer u.StopWatch()
u.SetOnUpdateListener(func() {
onUpdate = true