mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[management] Propagate user groups when group propagation setting is re-enabled (#3912)
This commit is contained in:
@@ -277,29 +277,11 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
// UpdateAccountSettings updates Account settings.
|
||||
// Only users with role UserRoleAdmin can update the account.
|
||||
// User that performs the update has to belong to the account.
|
||||
// Returns an updated Account
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
}
|
||||
|
||||
if newSettings.PeerLoginExpiration < time.Hour {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
// Returns an updated Settings
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
@@ -309,12 +291,118 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
|
||||
var oldSettings *types.Settings
|
||||
var updateAccountPeers bool
|
||||
var groupChangesAffectPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var groupsUpdated bool
|
||||
|
||||
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled {
|
||||
groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if updateAccountPeers || groupsUpdated {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldSettings := account.Settings
|
||||
extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.handleRoutingPeerDNSResolutionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newSettings, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
}
|
||||
|
||||
if newSettings.PeerLoginExpiration < time.Hour {
|
||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
|
||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||
}
|
||||
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled {
|
||||
if newSettings.RoutingPeerDNSResolutionEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil)
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleLazyConnectionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled {
|
||||
if newSettings.LazyConnectionEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil)
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
|
||||
event := activity.AccountPeerLoginExpirationEnabled
|
||||
if !newSettings.PeerLoginExpirationEnabled {
|
||||
@@ -330,82 +418,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
updateAccountPeers := false
|
||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled {
|
||||
if newSettings.RoutingPeerDNSResolutionEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil)
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil)
|
||||
}
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled {
|
||||
if newSettings.LazyConnectionEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil)
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil)
|
||||
}
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
if oldSettings.DNSDomain != newSettings.DNSDomain {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("groups propagation failed: %w", err)
|
||||
}
|
||||
|
||||
account.UpdateSettings(newSettings)
|
||||
|
||||
if updateAccountPeers {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged {
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
|
||||
if newSettings.GroupsPropagationEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
|
||||
// Todo: retroactively add user groups to all peers
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
@@ -1853,3 +1880,57 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||
// Returns true if any groups were modified, true if those updates affect peers and an error.
|
||||
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*types.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
groupsToUpdate := make(map[string]*types.Group)
|
||||
|
||||
for _, user := range users {
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
for _, group := range updatedGroups {
|
||||
groupsToUpdate[group.ID] = group
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupsToUpdate) == 0 {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate))
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate))
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
return true, peersAffected, nil
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ type Manager interface {
|
||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
|
||||
@@ -1805,9 +1805,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1825,11 +1826,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1856,6 +1857,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1919,9 +1921,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1935,6 +1938,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
failed = waitTimeout(wg, time.Second)
|
||||
@@ -1950,13 +1954,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
@@ -1967,12 +1972,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||
}
|
||||
@@ -3319,3 +3326,120 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
|
||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
|
||||
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
})
|
||||
|
||||
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
|
||||
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
})
|
||||
|
||||
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
|
||||
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group2.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
|
||||
Name: "Group1 Policy",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"group1"},
|
||||
Destinations: []string{"group2"},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.True(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
})
|
||||
|
||||
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = []string{"group1"}
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -138,7 +138,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta)
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -46,9 +46,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
accCopy := account.Copy()
|
||||
accCopy.UpdateSettings(newSettings)
|
||||
return accCopy, nil
|
||||
return newSettings, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
return account.Copy(), nil
|
||||
|
||||
@@ -90,7 +90,7 @@ type MockAccountManager struct {
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
@@ -662,7 +662,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
|
||||
}
|
||||
|
||||
@@ -2163,6 +2163,22 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveAccountSettings stores the account settings in DB.
|
||||
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save account settings to store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
@@ -72,6 +72,7 @@ type Store interface {
|
||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
|
||||
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
|
||||
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||
|
||||
@@ -1153,8 +1153,9 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp
|
||||
if !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
addUserPeersToGroup(userPeerIDMap, group)
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
if changed := addUserPeersToGroup(userPeerIDMap, group); changed {
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
}
|
||||
|
||||
for _, gid := range groupsToRemove {
|
||||
@@ -1162,45 +1163,65 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp
|
||||
if !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
removeUserPeersFromGroup(userPeerIDMap, group)
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed {
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
}
|
||||
|
||||
return groupsToUpdate, nil
|
||||
}
|
||||
|
||||
// addUserPeersToGroup adds the user's peers to the group.
|
||||
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) {
|
||||
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
|
||||
groupPeers := make(map[string]struct{}, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
}
|
||||
|
||||
changed := false
|
||||
for pid := range userPeerIDs {
|
||||
groupPeers[pid] = struct{}{}
|
||||
if _, exists := groupPeers[pid]; !exists {
|
||||
groupPeers[pid] = struct{}{}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = make([]string, 0, len(groupPeers))
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
if changed {
|
||||
group.Peers = make([]string, 0, len(groupPeers))
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// removeUserPeersFromGroup removes user's peers from the group.
|
||||
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) {
|
||||
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
|
||||
// skip removing peers from group All
|
||||
if group.Name == "All" {
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
updatedPeers := make([]string, 0, len(group.Peers))
|
||||
changed := false
|
||||
|
||||
for _, pid := range group.Peers {
|
||||
if _, found := userPeerIDs[pid]; !found {
|
||||
updatedPeers = append(updatedPeers, pid)
|
||||
if _, owned := userPeerIDs[pid]; owned {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
updatedPeers = append(updatedPeers, pid)
|
||||
}
|
||||
|
||||
group.Peers = updatedPeers
|
||||
if changed {
|
||||
group.Peers = updatedPeers
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
|
||||
Reference in New Issue
Block a user