From 4f0a3a77ad3e3fcdeab09e1945ebab8a66b94e6f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 5 Mar 2026 14:30:31 +0100 Subject: [PATCH] [management] Avoid breaking single acc mode when switching domains (#5511) * **Bug Fixes** * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data. * Improved error handling when account data is unavailable with fallback to configured default domain. * **Tests** * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data. --- management/server/account.go | 36 +++++++++- management/server/account_test.go | 114 ++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 3 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 550971337..01d0eebfa 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1379,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. - userAuth.Domain = am.singleAccountModeDomain - userAuth.DomainCategory = types.PrivateCategory - log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + err := am.updateUserAuthWithSingleMode(ctx, &userAuth) + if err != nil { + return "", "", err + } } accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth) @@ -1414,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u return accountID, user.Id, nil } +// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account +func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error { + userAuth.DomainCategory = types.PrivateCategory + userAuth.Domain = am.singleAccountModeDomain + + accountID, err := am.Store.GetAnyAccountID(ctx) + if err != nil { + if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound { + return err + } + log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode") + return nil + } + + if accountID == "" { + log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode") + return nil + } + + domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + userAuth.Domain = domain + + log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + return nil +} + // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager diff --git a/management/server/account_test.go b/management/server/account_test.go index 65bab6c18..a073d4fca 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/netbirdio/netbird/shared/management/status" "github.com/prometheus/client_golang/prometheus/push" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -3966,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange") } } + +func TestUpdateUserAuthWithSingleMode(t *testing.T) { + t.Run("sets defaults and overrides domain from store", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("account-1", nil) + mockStore.EXPECT(). + GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1"). + Return("real-domain.com", "private", nil) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.NoError(t, err) + assert.Equal(t, "real-domain.com", userAuth.Domain) + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("", nil) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.NoError(t, err) + assert.Equal(t, "fallback.com", userAuth.Domain) + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("", status.Errorf(status.NotFound, "no accounts")) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.NoError(t, err) + assert.Equal(t, "fallback.com", userAuth.Domain) + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("", status.Errorf(status.Internal, "db down")) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.Error(t, err) + assert.Contains(t, err.Error(), "db down") + // Defaults should still be set before error path + assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory) + }) + + t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockStore := store.NewMockStore(ctrl) + mockStore.EXPECT(). + GetAnyAccountID(gomock.Any()). + Return("account-1", nil) + mockStore.EXPECT(). + GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1"). + Return("", "", status.Errorf(status.Internal, "query failed")) + + am := &DefaultAccountManager{ + Store: mockStore, + singleAccountModeDomain: "fallback.com", + } + + userAuth := &auth.UserAuth{} + err := am.updateUserAuthWithSingleMode(context.Background(), userAuth) + require.Error(t, err) + assert.Contains(t, err.Error(), "query failed") + }) +}