[client] Consolidate authentication logic (#5010)

* Consolidate authentication logic

- Moving auth functions from client/internal to client/internal/auth package
- Creating unified auth.Auth client with NewAuth() constructor
- Replacing direct auth function calls with auth client methods
- Refactoring device flow and PKCE flow implementations
- Updating iOS/Android/server code to use new auth client API

* Refactor PKCE auth and login methods

- Remove unnecessary internal package reference in PKCE flow test
- Adjust context assignment placement in iOS and Android login methods
This commit is contained in:
Zoltan Papp
2026-01-23 22:28:32 +01:00
committed by GitHub
parent 67211010f7
commit ded04b7627
16 changed files with 805 additions and 724 deletions

View File

@@ -3,15 +3,7 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
err, _ = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
go urlOpener.OnLoginSuccess()
return nil
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
}
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -7,7 +7,6 @@ import (
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
}
jwtToken := ""
@@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
err, _ = authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return nil
@@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
@@ -176,7 +177,13 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}

View File

@@ -0,0 +1,499 @@
package auth
import (
"context"
"net/url"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// Auth manages authentication operations with the management server
// It maintains a long-lived connection and automatically handles reconnection with backoff
type Auth struct {
mutex sync.RWMutex
client *mgm.GrpcClient
config *profilemanager.Config
privateKey wgtypes.Key
mgmURL *url.URL
mgmTLSEnabled bool
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The connection is automatically recreated with backoff if it becomes disconnected
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
privateKey: myPrivateKey,
mgmURL: mgmURL,
mgmTLSEnabled: mgmTLSEnabled,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
a.mutex.Lock()
defer a.mutex.Unlock()
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
var supportsSSO bool
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
// Try PKCE flow first
_, err := a.getPKCEFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if PKCE is not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(client)
if err == nil {
supportsSSO = true
return nil
}
// Check if Device flow is also not supported
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
// Neither PKCE nor Device flow is supported
supportsSSO = false
return nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return err
})
return supportsSSO, err
}
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
// This avoids creating a new connection to the management server
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
var flow OAuthFlow
var err error
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
if forceDeviceAuth {
flow, err = a.getDeviceFlow(client)
return err
}
// Try PKCE flow first
flow, err = a.getPKCEFlow(client)
if err != nil {
// If PKCE not supported, try Device flow
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
flow, err = a.getDeviceFlow(client)
return err
}
return err
}
return nil
})
return flow, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
}
needsLogin = false
return err
})
return needsLogin, err
}
// Login attempts to log in or register the client with the management server
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
// Automatically retries with backoff and reconnection on connection errors.
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return err, false
}
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
} else if err != nil {
isAuthError = isPermissionDenied(err)
return err
}
isAuthError = false
return nil
})
return err, isAuthError
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
// setSystemInfoFlags sets all configuration flags on the provided system info
func (a *Auth) setSystemInfoFlags(info *system.Info) {
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
}
// reconnect closes the current connection and creates a new one
// It checks if the brokenClient is still the current client before reconnecting
// to avoid multiple threads reconnecting unnecessarily
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
a.mutex.Lock()
defer a.mutex.Unlock()
// Double-check: if client has already been replaced by another thread, skip reconnection
if a.client != brokenClient {
log.Debugf("client already reconnected by another thread, skipping")
return nil
}
// Create new connection FIRST, before closing the old one
// This ensures a.client is never nil, preventing panics in other threads
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
// Keep the old client if reconnection fails
return err
}
// Close old connection AFTER new one is successfully created
oldClient := a.client
a.client = mgmClient
if oldClient != nil {
if err := oldClient.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
return nil
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
// Capture the client BEFORE the operation to ensure we track the correct client
a.mutex.RLock()
currentClient := a.client
a.mutex.RUnlock()
if currentClient == nil {
return status.Errorf(codes.Unavailable, "client is not initialized")
}
// Execute operation with the captured client
err := operation(currentClient)
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection using the client that was actually used
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors, don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// isPermissionDenied checks if the error is a PermissionDenied error.
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
func isPermissionDenied(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
return s.Code() == codes.PermissionDenied
}
func isLoginNeeded(err error) bool {
return isAuthenticationError(err)
}
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -26,12 +25,56 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,8 +12,6 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
return pkceFlowInfo, nil
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
return deviceFlowInfo, nil
}

View File

@@ -20,7 +20,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -35,17 +34,67 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
providerConfig PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
excludedRanges := getSystemExcludedPortRanges()
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -9,7 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -9,8 +9,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
func TestParseExcludedPortRanges(t *testing.T) {
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
availablePort := 65432
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -1,136 +0,0 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -1,201 +0,0 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login or register the client
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
var mgmTlsEnabled bool
if mgmURL.Scheme == "https" {
mgmTlsEnabled = true
}
log.Debugf("connecting to the Management service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
return nil, err
}
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
return true
}
return false
}
func isRegistrationNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.PermissionDenied {
return true
}
return false
}

View File

@@ -1,138 +0,0 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
return true
}
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
return true // Assume login is required if we can't create auth client
}
defer authClient.Close()
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
log.Errorf("IsLoginRequired: check failed: %v", err)
// If the check fails, assume login is required to be safe
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
if err != nil {
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
return
}
jwtToken := tokenInfo.GetTokenToUse()
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
return
}
defer authClient.Close()
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
log.Errorf("LoginForMobile: Login failed: %v", err)
return
}

View File

@@ -7,13 +7,8 @@ import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
// which are blocked by the tvOS sandbox in App Group containers
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
err, _ = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
// LoginSync performs a synchronous login check without UI interaction
// Used for background VPN connection where user should already be authenticated
func (a *Auth) LoginSync() error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
return fmt.Errorf("not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
if err != nil {
if isAuthError {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
}
return fmt.Errorf("login failed: %v", err)
}
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
}
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
var needsLogin bool
// Create context with device name if provided
ctx := a.ctx
if deviceName != "" {
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
}
// check if we need to generate JWT token
err := a.withBackOff(ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
return
})
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(ctx, func() error {
err := internal.Login(ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
// PermissionDenied means registration is required or peer is blocked
return backoff.Permanent(err)
}
return err
})
err, isAuthError := authClient.Login(ctx, "", jwtToken)
if err != nil {
if isAuthError {
// PermissionDenied means registration is required or peer is blocked
return fmt.Errorf("authentication error: %v", err)
}
return fmt.Errorf("login failed: %v", err)
}
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
const authInfoRequestTimeout = 30 * time.Second
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
}
// Use a bounded timeout for the auth info request to prevent indefinite hangs
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
// GetConfigJSON returns the current config as a JSON string.
// This can be used by the caller to persist the config via alternative storage
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).

View File

@@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType
err := internal.Login(ctx, s.config, setupKey, jwtToken)
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
if err != nil {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Errorf("failed to create auth client: %v", err)
return internal.StatusLoginFailed, err
}
defer authClient.Close()
var status internal.StatusType
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
if isAuthError {
log.Warnf("failed login: %v", err)
status = internal.StatusNeedsLogin
} else {
@@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel()
}
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
waitCTX, cancel := context.WithCancel(ctx)
defer cancel()
s.mutex.Lock()