mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[management] Avoid breaking single acc mode when switching domains (#5511)
* **Bug Fixes** * Fixed domain configuration handling in single account mode to properly retrieve and apply domain settings from account data. * Improved error handling when account data is unavailable with fallback to configured default domain. * **Tests** * Added comprehensive test coverage for single account mode domain configuration scenarios, including edge cases for missing or unavailable account data.
This commit is contained in:
@@ -1379,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
// We override incoming domain claims to group users under a single account.
|
// We override incoming domain claims to group users under a single account.
|
||||||
userAuth.Domain = am.singleAccountModeDomain
|
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
|
||||||
userAuth.DomainCategory = types.PrivateCategory
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
return "", "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||||
@@ -1414,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
|||||||
return accountID, user.Id, nil
|
return accountID, user.Id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
|
||||||
|
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
|
||||||
|
userAuth.DomainCategory = types.PrivateCategory
|
||||||
|
userAuth.Domain = am.singleAccountModeDomain
|
||||||
|
|
||||||
|
accountID, err := am.Store.GetAnyAccountID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID == "" {
|
||||||
|
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userAuth.Domain = domain
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
// and propagates changes to peers if group propagation is enabled.
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/prometheus/client_golang/prometheus/push"
|
"github.com/prometheus/client_golang/prometheus/push"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -3966,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
|||||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
|
||||||
|
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("account-1", nil)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||||
|
Return("real-domain.com", "private", nil)
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "real-domain.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", nil)
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", status.Errorf(status.NotFound, "no accounts"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("", status.Errorf(status.Internal, "db down"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "db down")
|
||||||
|
// Defaults should still be set before error path
|
||||||
|
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAnyAccountID(gomock.Any()).
|
||||||
|
Return("account-1", nil)
|
||||||
|
mockStore.EXPECT().
|
||||||
|
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||||
|
Return("", "", status.Errorf(status.Internal, "query failed"))
|
||||||
|
|
||||||
|
am := &DefaultAccountManager{
|
||||||
|
Store: mockStore,
|
||||||
|
singleAccountModeDomain: "fallback.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth := &auth.UserAuth{}
|
||||||
|
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "query failed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user