[management] Propagate user groups when group propagation setting is re-enabled (#3912)

This commit is contained in:
Bethuel Mmbaga
2025-06-11 14:32:16 +03:00
committed by GitHub
parent 75feb0da8b
commit 4ee1635baa
9 changed files with 352 additions and 111 deletions

View File

@@ -277,29 +277,11 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// UpdateAccountSettings updates Account settings.
// Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account.
// Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
}
if newSettings.PeerLoginExpiration < time.Hour {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
// Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
@@ -309,12 +291,118 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.NewPermissionDeniedError()
}
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
var oldSettings *types.Settings
var updateAccountPeers bool
var groupChangesAffectPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain {
updateAccountPeers = true
}
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled {
groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID)
if err != nil {
return err
}
}
if updateAccountPeers || groupsUpdated {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
}
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings)
})
if err != nil {
return nil, err
}
oldSettings := account.Settings
extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra)
if err != nil {
return nil, err
}
am.handleRoutingPeerDNSResolutionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
go am.UpdateAccountPeers(ctx, accountID)
}
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
}
if newSettings.PeerLoginExpiration < time.Hour {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
}
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled {
if newSettings.RoutingPeerDNSResolutionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil)
}
}
}
func (am *DefaultAccountManager) handleLazyConnectionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled {
if newSettings.LazyConnectionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil)
}
}
}
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled {
@@ -330,82 +418,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
updateAccountPeers := false
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled {
if newSettings.RoutingPeerDNSResolutionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil)
}
updateAccountPeers = true
}
if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled {
if newSettings.LazyConnectionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil)
}
updateAccountPeers = true
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
updateAccountPeers = true
}
err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
account.UpdateSettings(newSettings)
if updateAccountPeers {
account.Network.Serial++
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra)
if err != nil {
return nil, err
}
if updateAccountPeers || extraSettingsChanged {
go am.UpdateAccountPeers(ctx, accountID)
}
return account, nil
}
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
if newSettings.GroupsPropagationEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
// Todo: retroactively add user groups to all peers
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
}
}
return nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
@@ -1853,3 +1880,57 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return account, nil
}
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsToUpdate := make(map[string]*types.Group)
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
if err != nil {
return false, false, err
}
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil)
if err != nil {
return false, false, err
}
for _, group := range updatedGroups {
groupsToUpdate[group.ID] = group
groupsMap[group.ID] = group
}
}
if len(groupsToUpdate) == 0 {
return false, false, nil
}
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate))
if err != nil {
return false, false, err
}
err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate))
if err != nil {
return false, false, err
}
return true, peersAffected, nil
}

View File

@@ -88,7 +88,7 @@ type Manager interface {
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)

View File

@@ -1805,9 +1805,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1825,11 +1826,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
@@ -1856,6 +1857,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1919,9 +1921,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1935,6 +1938,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second)
@@ -1950,13 +1954,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
@@ -1967,12 +1972,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
@@ -3319,3 +3326,120 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
})
})
}
func TestPropagateUserGroupMemberships(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
})
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
require.NoError(t, err)
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
})
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
Name: "Group1 Policy",
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"group1"},
Destinations: []string{"group2"},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
}
})
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
})
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
}
})
}

View File

@@ -126,7 +126,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -138,7 +138,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta)
resp := toAccountResponse(accountID, updatedSettings, meta)
util.WriteJSONObject(r.Context(), w, &resp)
}

View File

@@ -36,7 +36,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -46,9 +46,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
accCopy := account.Copy()
accCopy.UpdateSettings(newSettings)
return accCopy, nil
return newSettings, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
return account.Copy(), nil

View File

@@ -90,7 +90,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
@@ -662,7 +662,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
}

View File

@@ -2163,6 +2163,22 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre
return nil
}
// SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save account settings to store")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -72,6 +72,7 @@ type Store interface {
DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)

View File

@@ -1153,8 +1153,9 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp
if !ok {
return nil, errors.New("group not found")
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
if changed := addUserPeersToGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group)
}
}
for _, gid := range groupsToRemove {
@@ -1162,45 +1163,65 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp
if !ok {
return nil, errors.New("group not found")
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group)
}
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) {
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
changed := false
for pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
if _, exists := groupPeers[pid]; !exists {
groupPeers[pid] = struct{}{}
changed = true
}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
if changed {
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
return changed
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) {
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
// skip removing peers from group All
if group.Name == "All" {
return
return false
}
updatedPeers := make([]string, 0, len(group.Peers))
changed := false
for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found {
updatedPeers = append(updatedPeers, pid)
if _, owned := userPeerIDs[pid]; owned {
changed = true
continue
}
updatedPeers = append(updatedPeers, pid)
}
group.Peers = updatedPeers
if changed {
group.Peers = updatedPeers
}
return changed
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {