Compare commits

...

1 Commits

Author SHA1 Message Date
Zoltán Papp
4b3e1f1b52 Refactor gRPC auth and retry logic
- Centralize retry logic in auth layer
- Decouple gRPC connection logic with new Connect method
- Refactor management client to fetch server public key internally
- Add dedicated HealthCheck method for connection verification
- Simplify getServerPublicKey by removing retry logic
2025-12-24 11:26:51 +01:00
23 changed files with 1008 additions and 932 deletions

View File

@@ -3,15 +3,7 @@ package android
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -131,17 +110,15 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
err = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
@@ -160,15 +137,16 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
@@ -180,22 +158,13 @@ func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
jwtToken = tokenInfo.GetTokenToUse()
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
err = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
go urlOpener.OnLoginSuccess()
return nil
}
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
return &tokenInfo, nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -2,13 +2,13 @@ package cmd
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/util"
)
@@ -277,18 +278,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
needsLogin := false
err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
err = authClient.Login(ctx, "", "")
if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
}
jwtToken := ""
@@ -300,23 +302,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}
var lastError error
err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})
if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}
err = authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return nil
@@ -344,11 +332,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
@@ -168,7 +169,13 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()
if err := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}

View File

@@ -0,0 +1,287 @@
package auth
import (
"context"
"errors"
"fmt"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// Auth manages authentication operations with the management server
// The underlying management client handles connection retry and reconnection automatically
type Auth struct {
client *mgm.GrpcClient
config *profilemanager.Config
}
// NewAuth creates a new Auth instance that manages authentication flows
// It establishes a connection to the management server that will be reused for all operations
// The management client handles connection retry and reconnection automatically
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
// Validate WireGuard private key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
// Determine TLS setting based on URL scheme
mgmTLSEnabled := mgmURL.Scheme == "https"
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient := mgm.NewClient(mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err := mgmClient.Connect(ctx); err != nil {
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
return nil, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
return &Auth{
client: mgmClient,
config: config,
}, nil
}
// Close closes the management client connection
func (a *Auth) Close() error {
if a.client == nil {
return nil
}
return a.client.Close()
}
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
// Try PKCE flow first
_, err := a.getPKCEFlow(ctx)
if err == nil {
return true, nil
}
// Check if PKCE is not supported
if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) {
// PKCE not supported, try Device flow
_, err = a.getDeviceFlow(ctx)
if err == nil {
return true, nil
}
// Check if Device flow is also not supported
if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) {
// Neither PKCE nor Device flow is supported
return false, nil
}
// Device flow check returned an error other than NotFound/Unimplemented
return false, err
}
// PKCE flow check returned an error other than NotFound/Unimplemented
return false, err
}
// IsLoginRequired checks if login is required by attempting to authenticate with the server
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return false, err
}
err = a.doMgmLogin(ctx, pubSSHKey)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login attempts to log in or register the client with the management server
// Returns custom errors from mgm package: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) error {
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
if err != nil {
return fmt.Errorf("generate SSH public key: %w", err)
}
err = a.doMgmLogin(ctx, pubSSHKey)
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
return a.registerPeer(ctx, setupKey, jwtToken, pubSSHKey)
}
return err
}
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(ctx context.Context) (*PKCEAuthorizationFlow, error) {
protoFlow, err := a.client.GetPKCEAuthorizationFlow(ctx)
if err != nil {
if errors.Is(err, mgm.ErrNotFound) {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
RedirectURLs: protoConfig.GetRedirectURLs(),
UseIDToken: protoConfig.GetUseIDToken(),
ClientCertPair: a.config.ClientCertKeyPair,
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
}
if err := validatePKCEConfig(config); err != nil {
return nil, err
}
flow, err := NewPKCEAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(ctx context.Context) (*DeviceAuthorizationFlow, error) {
protoFlow, err := a.client.GetDeviceAuthorizationFlow(ctx)
if err != nil {
if errors.Is(err, mgm.ErrNotFound) {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return nil, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return nil, err
}
protoConfig := protoFlow.GetProviderConfig()
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
Scope: protoConfig.GetScope(),
UseIDToken: protoConfig.GetUseIDToken(),
}
// Keep compatibility with older management versions
if config.Scope == "" {
config.Scope = "openid"
}
if err := validateDeviceAuthConfig(config); err != nil {
return nil, err
}
flow, err := NewDeviceAuthorizationFlow(*config)
if err != nil {
return nil, err
}
return flow, nil
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(ctx context.Context, pubSSHKey []byte) error {
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
_, err := a.client.Login(ctx, sysInfo, pubSSHKey, a.config.DNSLabels)
return err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) error {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return fmt.Errorf("%w: invalid setup-key or no SSO information provided: %v", mgm.ErrInvalidArgument, err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
a.config.DisableFirewall,
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
)
// todo: fix error handling of validSetupKey
if err := a.client.Register(ctx, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels); err != nil {
log.Errorf("failed registering peer %v", err)
return err
}
log.Infof("peer has been successfully registered on Management Service")
return nil
}
// isPermissionDenied checks if the error is a PermissionDenied error
func isPermissionDenied(err error) bool {
return errors.Is(err, mgm.ErrPermissionDenied)
}
// isLoginNeeded checks if the error indicates login is required
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
return errors.Is(err, mgm.ErrInvalidArgument) ||
errors.Is(err, mgm.ErrPermissionDenied) ||
errors.Is(err, mgm.ErrUnauthenticated)
}
// isRegistrationNeeded checks if the error indicates peer registration is needed
func isRegistrationNeeded(err error) bool {
return isPermissionDenied(err)
}

View File

@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/embeddedroots"
)
@@ -26,12 +25,56 @@ const (
var _ OAuthFlow = &DeviceAuthorizationFlow{}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validateDeviceAuthConfig validates device authorization provider configuration
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMsgFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
}
return nil
}
// DeviceAuthorizationFlow implements the OAuthFlow interface,
// for the Device Authorization Flow.
type DeviceAuthorizationFlow struct {
providerConfig internal.DeviceAuthProviderConfig
HTTPClient HTTPClient
providerConfig DeviceAuthProviderConfig
HTTPClient HTTPClient
}
// RequestDeviceCodePayload used for request device code payload for auth0
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
}
// NewDeviceAuthorizationFlow returns device authorization flow client
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
return d.providerConfig.ClientID
}
// SetLoginHint sets the login hint for the device authorization flow
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
d.providerConfig.LoginHint = hint
}
// RequestAuthInfo requests a device code login flow information from Hosted
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
form := url.Values{}
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
// it retrieves the access token from Hosted's endpoint and validates it before returning.
// The method creates a timeout context internally based on info.ExpiresIn.
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
interval := time.Duration(info.Interval) * time.Second
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case <-ticker.C:
tokenResponse, err := d.requestToken(info)

View File

@@ -12,8 +12,6 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
)
type mockHTTPClient struct {
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
err: testCase.inputReqError,
}
deviceFlow := &DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: expectedAudience,
ClientID: expectedClientID,
Scope: expectedScope,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
countResBody: testCase.inputCountResBody,
}
deviceFlow := DeviceAuthorizationFlow{
providerConfig: internal.DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
},
HTTPClient: &httpClient,
config := DeviceAuthProviderConfig{
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
Scope: "openid",
UseIDToken: false,
}
deviceFlow, err := NewDeviceAuthorizationFlow(config)
require.NoError(t, err, "creating device flow should not fail")
deviceFlow.HTTPClient = &httpClient
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
defer cancel()
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
pkceFlowInfo, err := authClient.getPKCEFlow(ctx)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
pkceFlowInfo.SetLoginHint(hint)
}
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
return pkceFlowInfo, nil
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
deviceFlowInfo, err := authClient.getDeviceFlow(ctx)
if err != nil {
switch s, ok := gstatus.FromError(err); {
case ok && s.Code() == codes.NotFound:
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
if hint != "" {
deviceFlowInfo.SetLoginHint(hint)
}
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
return deviceFlowInfo, nil
}

View File

@@ -19,8 +19,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/templates"
"github.com/netbirdio/netbird/shared/management/client/common"
)
var _ OAuthFlow = &PKCEAuthorizationFlow{}
@@ -33,17 +33,67 @@ const (
defaultPKCETimeoutSeconds = 300
)
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// validatePKCEConfig validates PKCE provider configuration
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMsgFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
}
return nil
}
// PKCEAuthorizationFlow implements the OAuthFlow interface for
// the Authorization Code Flow with PKCE.
type PKCEAuthorizationFlow struct {
providerConfig internal.PKCEAuthProviderConfig
providerConfig PKCEAuthProviderConfig
state string
codeVerifier string
oAuthConfig *oauth2.Config
}
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
var availableRedirectURL string
// find the first available redirect URL
@@ -121,10 +171,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
}, nil
}
// SetLoginHint sets the login hint for the PKCE authorization flow
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
p.providerConfig.LoginHint = hint
}
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
// Once the token is received, it is converted to TokenInfo and validated before returning.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
// The method creates a timeout context internally based on info.ExpiresIn.
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
// Create timeout context based on flow expiration
timeout := time.Duration(info.ExpiresIn) * time.Second
waitCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
tokenChan := make(chan *oauth2.Token, 1)
errChan := make(chan error, 1)
@@ -135,7 +196,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
defer func() {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
@@ -146,8 +207,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
go p.startServer(server, tokenChan, errChan)
select {
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-waitCtx.Done():
return TokenInfo{}, waitCtx.Err()
case token := <-tokenChan:
return p.parseOAuthToken(token)
case err := <-errChan:

View File

@@ -6,7 +6,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/shared/management/client/common"
)
@@ -41,7 +40,7 @@ func TestPromptLogin(t *testing.T) {
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
config := PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",

View File

@@ -162,6 +162,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return err
}
// Create management client once outside retry loop
mgmClient := mgm.NewClient(c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
defer c.statusRecorder.ClientStop()
operation := func() error {
// if context cancelled we not start new backoff cycle
@@ -180,12 +185,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}()
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
if err := mgmClient.ConnectWithoutRetry(engineCtx); err != nil {
return wrapErr(fmt.Errorf("failed connecting to Management Service: %w", err))
}
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
@@ -198,7 +200,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
if errors.Is(err, mgm.ErrPermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
@@ -320,7 +322,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
err = backoff.Retry(operation, backOff)
if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
if errors.Is(err, mgm.ErrPermissionDenied) {
state.Set(StatusNeedsLogin)
_ = c.Stop()
}
@@ -504,12 +506,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
@@ -528,7 +524,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
loginResp, err := client.Login(ctx, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}

View File

@@ -1,136 +0,0 @@
package internal
import (
"context"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
)
// DeviceAuthorizationFlow represents Device Authorization Flow information
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig DeviceAuthProviderConfig
}
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
type DeviceAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return DeviceAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return DeviceAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return DeviceAuthorizationFlow{}, err
}
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
return DeviceAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve device flow: %v", err)
return DeviceAuthorizationFlow{}, err
}
deviceAuthorizationFlow := DeviceAuthorizationFlow{
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: DeviceAuthProviderConfig{
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
},
}
// keep compatibility with older management versions
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
}
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
if err != nil {
return DeviceAuthorizationFlow{}, err
}
return deviceAuthorizationFlow, nil
}
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.Audience == "" {
return fmt.Errorf(errorMSGFormat, "Audience")
}
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.DeviceAuthEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
}
return nil
}

View File

@@ -891,7 +891,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
if err := e.mgmClient.SyncMeta(e.ctx, info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
@@ -1517,7 +1517,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.DisableSSHAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)
netMap, err := e.mgmClient.GetNetworkMap(e.ctx, info)
if err != nil {
return nil, nil, false, err
}
@@ -1666,7 +1666,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy()
managementHealthy := e.mgmClient.IsHealthy(e.ctx)
log.Debugf("management health check: healthy=%t", managementHealthy)
stuns := slices.Clone(e.STUNs)

View File

@@ -1503,8 +1503,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
mgmtClient := mgmt.NewClient(mgmtAddr, key, false)
if err := mgmtClient.Connect(ctx); err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
@@ -1512,13 +1512,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
err = mgmtClient.Register(ctx, setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1531,9 +1526,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
wgPort := 33100 + i
testAddr := fmt.Sprintf("100.64.0.%d/24", i+1)
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: resp.PeerConfig.Address,
WgAddr: testAddr,
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,

View File

@@ -1,201 +0,0 @@
package internal
import (
"context"
"net/url"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
return false, err
}
// Login or register the client
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
return err
}
defer func() {
err = mgmClient.Close()
if err != nil {
cStatus, ok := status.FromError(err)
if !ok || ok && cStatus.Code() != codes.Canceled {
log.Warnf("failed to close the Management service client, err: %v", err)
}
}
}()
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
if err != nil {
return err
}
} else if err != nil {
return err
}
return nil
}
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return nil, err
}
var mgmTlsEnabled bool
if mgmURL.Scheme == "https" {
mgmTlsEnabled = true
}
log.Debugf("connecting to the Management service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil {
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
return nil, err
}
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
}
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
}
log.Infof("peer has been successfully registered on Management Service")
return loginResp, nil
}
func isLoginNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
return true
}
return false
}
func isRegistrationNeeded(err error) bool {
if err == nil {
return false
}
s, ok := status.FromError(err)
if !ok {
return false
}
if s.Code() == codes.PermissionDenied {
return true
}
return false
}

View File

@@ -1,138 +0,0 @@
package internal
import (
"context"
"crypto/tls"
"fmt"
"net/url"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/client/common"
)
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
type PKCEAuthorizationFlow struct {
ProviderConfig PKCEAuthProviderConfig
}
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
type PKCEAuthProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
return PKCEAuthorizationFlow{}, err
}
var mgmTLSEnabled bool
if mgmURL.Scheme == "https" {
mgmTLSEnabled = true
}
log.Debugf("connecting to Management Service %s", mgmURL.String())
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
if err != nil {
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
return PKCEAuthorizationFlow{}, err
}
log.Debugf("connected to the Management service %s", mgmURL.String())
defer func() {
err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return PKCEAuthorizationFlow{}, err
}
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
return PKCEAuthorizationFlow{}, err
}
log.Errorf("failed to retrieve pkce flow: %v", err)
return PKCEAuthorizationFlow{}, err
}
authFlow := PKCEAuthorizationFlow{
ProviderConfig: PKCEAuthProviderConfig{
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
},
}
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
if err != nil {
return PKCEAuthorizationFlow{}, err
}
return authFlow, nil
}
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID")
}
if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
}
if config.AuthorizationEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
}
if config.Scope == "" {
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
}
if config.RedirectURLs == nil {
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
}
return nil
}

View File

@@ -730,8 +730,8 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
return config, err
}
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
if err != nil {
client := mgm.NewClient(newURL.Host, key, mgmTlsEnabled)
if err := client.Connect(ctx); err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return config, err
}
@@ -743,8 +743,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
if err := client.HealthCheck(ctx); err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}

View File

@@ -208,7 +208,13 @@ func (c *Client) IsLoginRequired() bool {
ConfigPath: c.cfgFile,
})
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
return true // Assume login is required if we can't create auth client
}
defer authClient.Close()
needsLogin, _ := authClient.IsLoginRequired(ctx)
return needsLogin
}
@@ -240,15 +246,17 @@ func (c *Client) LoginForMobile() string {
// This could cause a potential race condition with loading the extension which need to be handled on swift side
go func() {
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
if err != nil {
return
}
jwtToken := tokenInfo.GetTokenToUse()
_ = internal.Login(ctx, cfg, "", jwtToken)
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
if err != nil {
return
}
defer authClient.Close()
_ = authClient.Login(ctx, "", jwtToken)
c.loginComplete = true
}()

View File

@@ -3,15 +3,8 @@ package NetBirdSDK
import (
"context"
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
)
@@ -71,30 +64,21 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
// is not supported and returns false without saving the configuration. For other errors return false.
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
supportsSSO = false
err = nil
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
return err
}
return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}
if !supportsSSO {
return false, nil
}
if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
@@ -103,32 +87,31 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
err = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("login failed: %v", err)
}
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
}
func (a *Auth) Login() error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}
jwtToken := ""
@@ -136,25 +119,10 @@ func (a *Auth) Login() error {
return fmt.Errorf("Not authenticated")
}
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
err = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}
return nil
}
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}

View File

@@ -278,14 +278,22 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
var status internal.StatusType
err := internal.Login(ctx, s.config, setupKey, jwtToken)
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
if err != nil {
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
log.Warnf("failed login: %v", err)
log.Errorf("failed to create auth client: %v", err)
return internal.StatusLoginFailed, err
}
defer authClient.Close()
var status internal.StatusType
err = authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
// Check if it's an authentication error (permission denied, invalid credentials, unauthenticated)
if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) {
log.Warnf("authentication failed: %v", err)
status = internal.StatusNeedsLogin
} else {
log.Errorf("failed login: %v", err)
log.Errorf("login failed: %v", err)
status = internal.StatusLoginFailed
}
return status, err
@@ -606,8 +614,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
s.oauthAuthFlow.waitCancel()
}
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
waitCTX, cancel := context.WithCancel(ctx)
defer cancel()
s.mutex.Lock()
@@ -1002,8 +1009,8 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
}
mgmTlsEnabled := config.ManagementURL.Scheme == "https"
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled)
if err != nil {
mgmClient := mgm.NewClient(config.ManagementURL.Host, key, mgmTlsEnabled)
if err := mgmClient.Connect(ctx); err != nil {
return fmt.Errorf("connect to management server: %w", err)
}
defer func() {
@@ -1012,7 +1019,7 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
}
}()
return mgmClient.Logout()
return mgmClient.Logout(ctx)
}
// Status returns the daemon status

View File

@@ -4,8 +4,6 @@ import (
"context"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -14,13 +12,13 @@ import (
type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
Logout() error
Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) error
Login(ctx context.Context, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy(ctx context.Context) bool
HealthCheck(ctx context.Context) error
SyncMeta(ctx context.Context, sysInfo *system.Info) error
Logout(ctx context.Context) error
}

View File

@@ -193,12 +193,12 @@ func TestClient_GetServerPublicKey(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
key, err := client.getServerPublicKey(ctx)
if err != nil {
t.Error("couldn't retrieve management public key")
}
@@ -216,16 +216,12 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil)
_, err = client.Login(ctx, sysInfo, nil, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
@@ -243,24 +239,16 @@ func TestClient_LoginRegistered(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
err = client.Register(ctx, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
if resp == nil {
t.Error("expecting non nil response, got nil")
}
}
func TestClient_Sync(t *testing.T) {
@@ -272,18 +260,13 @@ func TestClient_Sync(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
err = client.Register(ctx, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -293,13 +276,14 @@ func TestClient_Sync(t *testing.T) {
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
remoteClient := NewClient(listener.Addr().String(), remoteKey, false)
remoteCtx := context.TODO()
if err := remoteClient.Connect(remoteCtx); err != nil {
t.Fatal(err)
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
err = remoteClient.Register(remoteCtx, ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -354,14 +338,9 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
serverAddr := lis.Addr().String()
ctx := context.Background()
testClient, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
testClient := NewClient(serverAddr, testKey, false)
if err := testClient.Connect(ctx); err != nil {
t.Fatalf("error while connecting testClient: %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
@@ -400,7 +379,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
err = testClient.Register(ctx, ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
@@ -489,9 +468,9 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
client := NewClient(serverAddr, testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatalf("error while connecting testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
@@ -512,7 +491,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
flowInfo, err := client.GetDeviceAuthorizationFlow(ctx)
if err != nil {
t.Error("error while retrieving device auth flow information")
}
@@ -533,9 +512,9 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
client := NewClient(serverAddr, testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatalf("error while connecting testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
@@ -558,7 +537,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
flowInfo, err := client.GetPKCEAuthorizationFlow(ctx)
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}

View File

@@ -25,6 +25,24 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
// Custom management client errors that abstract away gRPC error codes
var (
// ErrPermissionDenied is returned when the server denies access to a resource
ErrPermissionDenied = errors.New("permission denied")
// ErrInvalidArgument is returned when the request contains invalid arguments
ErrInvalidArgument = errors.New("invalid argument")
// ErrUnauthenticated is returned when authentication is required
ErrUnauthenticated = errors.New("unauthenticated")
// ErrNotFound is returned when the requested resource is not found
ErrNotFound = errors.New("not found")
// ErrUnimplemented is returned when the operation is not implemented
ErrUnimplemented = errors.New("not implemented")
)
const ConnectTimeout = 10 * time.Second
const (
@@ -41,40 +59,67 @@ type ConnStateNotifier interface {
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
addr string
tlsEnabled bool
reconnectMutex sync.Mutex
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
// The client is not connected after creation - call Connect to establish the connection
func NewClient(addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) *GrpcClient {
return &GrpcClient{
key: ourPrivateKey,
addr: addr,
tlsEnabled: tlsEnabled,
connStateCallbackLock: sync.RWMutex{},
reconnectMutex: sync.Mutex{},
}
}
// Connect establishes a connection to the Management Service with retry logic
// Retries connection attempts with exponential backoff on failure
func (c *GrpcClient) Connect(ctx context.Context) error {
var conn *grpc.ClientConn
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
conn, err = nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
return fmt.Errorf("create connection: %w", err)
log.Warnf("failed to connect to Management Service: %v", err)
return err
}
return nil
}
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
if err := backoff.Retry(operation, defaultBackoff(ctx)); err != nil {
log.Errorf("failed creating connection to Management Service after retries: %v", err)
return fmt.Errorf("create connection: %w", err)
}
realClient := proto.NewManagementServiceClient(conn)
c.conn = conn
c.realClient = proto.NewManagementServiceClient(conn)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil
log.Infof("connected to the Management Service at %s", c.addr)
return nil
}
// ConnectWithoutRetry establishes a connection to the Management Service without retry logic
// Performs a single connection attempt - callers should implement their own retry logic if needed
func (c *GrpcClient) ConnectWithoutRetry(ctx context.Context) error {
conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Warnf("failed to connect to Management Service: %v", err)
return fmt.Errorf("create connection: %w", err)
}
c.conn = conn
c.realClient = proto.NewManagementServiceClient(conn)
log.Debugf("connected to the Management Service at %s", c.addr)
return nil
}
// Close closes connection to the Management Service
@@ -89,19 +134,6 @@ func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *GrpcClient) ready() bool {
@@ -122,7 +154,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -177,15 +209,13 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
func (c *GrpcClient) GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
@@ -264,30 +294,32 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
// getServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
// This is a simple operation without retry logic - callers should handle retries at the operation level
func (c *GrpcClient) getServerPublicKey(ctx context.Context) (*wgtypes.Key, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
}
serverKey, err := wgtypes.ParseKey(resp.Key)
key, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
return &key, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
func (c *GrpcClient) IsHealthy() bool {
func (c *GrpcClient) IsHealthy(ctx context.Context) bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
return false
@@ -299,10 +331,10 @@ func (c *GrpcClient) IsHealthy() bool {
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
healthCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
_, err := c.realClient.GetServerKey(healthCtx, &proto.Empty{})
if err != nil {
c.notifyDisconnected(err)
log.Warnf("health check returned: %s", err)
@@ -312,12 +344,26 @@ func (c *GrpcClient) IsHealthy() bool {
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
// HealthCheck verifies connectivity to the management server
// Returns an error if the server is not reachable
// Internally uses getServerPublicKey to verify the connection
func (c *GrpcClient) HealthCheck(ctx context.Context) error {
_, err := c.getServerPublicKey(ctx)
return err
}
func (c *GrpcClient) login(ctx context.Context, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return nil, err
}
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
@@ -325,7 +371,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
var err error
@@ -344,14 +390,14 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
err = backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
@@ -363,99 +409,135 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
func (c *GrpcClient) Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) error {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
_, err := c.login(ctx, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return wrapGRPCError(err)
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
func (c *GrpcClient) Login(ctx context.Context, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
resp, err := c.login(ctx, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return resp, wrapGRPCError(err)
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
// It automatically retries with backoff and reconnection on connection errors.
// Returns custom errors: ErrNotFound, ErrUnimplemented
func (c *GrpcClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) {
var flowInfoResp *proto.DeviceAuthorizationFlow
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
err := c.withRetry(ctx, func() error {
if !c.ready() {
return fmt.Errorf("no connection to management in order to get device authorization flow")
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return nil, err
}
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2)
defer cancel()
return flowInfoResp, nil
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return err
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return err
}
flowInfo := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return errWithMSG
}
flowInfoResp = flowInfo
return nil
})
return flowInfoResp, wrapGRPCError(err)
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
// It automatically retries with backoff and reconnection on connection errors.
// Returns custom errors: ErrNotFound, ErrUnimplemented
func (c *GrpcClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) {
var flowInfoResp *proto.PKCEAuthorizationFlow
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
err := c.withRetry(ctx, func() error {
if !c.ready() {
return fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return err
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return err
}
flowInfo := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return errWithMSG
}
flowInfoResp = flowInfo
return nil
})
if err != nil {
return nil, err
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
return flowInfoResp, wrapGRPCError(err)
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
func (c *GrpcClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error {
if !c.ready() {
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -467,7 +549,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
@@ -497,13 +579,13 @@ func (c *GrpcClient) notifyConnected() {
c.connStateCallback.MarkManagementConnected()
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
func (c *GrpcClient) Logout(ctx context.Context) error {
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*15)
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*15)
defer cancel()
message := &proto.Empty{}
@@ -523,6 +605,156 @@ func (c *GrpcClient) Logout() error {
return nil
}
// reconnect closes the current connection and creates a new one
func (c *GrpcClient) reconnect(ctx context.Context) error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
// Close existing connection
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
// Create new connection
log.Debugf("reconnecting to Management Service %s", c.addr)
conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", c.addr, err)
return fmt.Errorf("reconnect: create connection: %w", err)
}
c.conn = conn
c.realClient = proto.NewManagementServiceClient(conn)
log.Debugf("successfully reconnected to Management service %s", c.addr)
return nil
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (c *GrpcClient) withRetry(ctx context.Context, operation func() error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
err := operation()
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := c.reconnect(ctx); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors (InvalidArgument, PermissionDenied), don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := gstatus.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := gstatus.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// wrapGRPCError converts gRPC errors to custom management client errors
func wrapGRPCError(err error) error {
if err == nil {
return nil
}
// Check if it's already a custom error
if errors.Is(err, ErrPermissionDenied) ||
errors.Is(err, ErrInvalidArgument) ||
errors.Is(err, ErrUnauthenticated) ||
errors.Is(err, ErrNotFound) ||
errors.Is(err, ErrUnimplemented) {
return err
}
// Convert gRPC status errors to custom errors
s, ok := gstatus.FromError(err)
if !ok {
return err
}
switch s.Code() {
case codes.PermissionDenied:
return fmt.Errorf("%w: %s", ErrPermissionDenied, s.Message())
case codes.InvalidArgument:
return fmt.Errorf("%w: %s", ErrInvalidArgument, s.Message())
case codes.Unauthenticated:
return fmt.Errorf("%w: %s", ErrUnauthenticated, s.Message())
case codes.NotFound:
return fmt.Errorf("%w: %s", ErrNotFound, s.Message())
case codes.Unimplemented:
return fmt.Errorf("%w: %s", ErrUnimplemented, s.Message())
default:
return err
}
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil

View File

@@ -3,8 +3,6 @@ package client
import (
"context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -13,17 +11,21 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
RegisterFunc func(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error
LoginFunc func(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(ctx context.Context) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(ctx context.Context) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(ctx context.Context, sysInfo *system.Info) error
HealthCheckFunc func(ctx context.Context) error
LogoutFunc func(ctx context.Context) error
IsHealthyFunc func(ctx context.Context) bool
}
func (m *MockClient) IsHealthy() bool {
return true
func (m *MockClient) IsHealthy(ctx context.Context) bool {
if m.IsHealthyFunc == nil {
return true
}
return m.IsHealthyFunc(ctx)
}
func (m *MockClient) Close() error {
@@ -40,56 +42,56 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Register(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error {
if m.RegisterFunc == nil {
return nil, nil
return nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
return m.RegisterFunc(ctx, setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Login(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
return m.LoginFunc(ctx, info, sshKey, dnsLabels)
}
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
func (m *MockClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
return m.GetDeviceAuthorizationFlowFunc(ctx)
}
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
func (m *MockClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetPKCEAuthorizationFlow(serverKey)
return m.GetPKCEAuthorizationFlowFunc(ctx)
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
func (m *MockClient) GetNetworkMap(ctx context.Context, _ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
func (m *MockClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
return m.SyncMetaFunc(ctx, sysInfo)
}
func (m *MockClient) Logout() error {
func (m *MockClient) HealthCheck(ctx context.Context) error {
if m.HealthCheckFunc == nil {
return nil
}
return m.HealthCheckFunc(ctx)
}
func (m *MockClient) Logout(ctx context.Context) error {
if m.LogoutFunc == nil {
return nil
}
return m.LogoutFunc()
return m.LogoutFunc(ctx)
}