From 552dc605479eb92822df107eb78e0c1326290bc2 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 1 Aug 2025 12:22:07 +0200 Subject: [PATCH] [management] migrate group peers into seperate table (#4096) --- management/server/account.go | 95 +++--- management/server/account/manager.go | 6 +- management/server/account_test.go | 31 +- management/server/dns_test.go | 8 +- management/server/group.go | 230 +++++++++++-- management/server/group_test.go | 323 ++++++++++++++++-- .../http/handlers/groups/groups_handler.go | 4 +- management/server/migration/migration.go | 64 ++++ management/server/mock_server/account_mock.go | 28 ++ management/server/nameserver_test.go | 24 +- management/server/peer.go | 56 ++- management/server/peer_test.go | 33 +- management/server/policy_test.go | 10 +- management/server/posture_checks_test.go | 25 +- management/server/route_test.go | 21 +- management/server/setupkey_test.go | 16 +- management/server/store/sql_store.go | 304 +++++++++++++---- management/server/store/sql_store_test.go | 113 ++++-- management/server/store/store.go | 23 +- management/server/types/account.go | 2 +- management/server/types/group.go | 32 +- management/server/types/setupkey.go | 2 +- management/server/user.go | 106 +----- management/server/user_test.go | 4 +- 24 files changed, 1139 insertions(+), 421 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 52b625da1..d392cd0b9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1368,7 +1368,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { + if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } @@ -1382,28 +1382,22 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) - if err != nil { - return fmt.Errorf("error getting account groups: %w", err) - } - - groupsMap := make(map[string]*types.Group, len(groups)) - for _, group := range groups { - groupsMap[group.ID] = group - } - peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } - updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) - if err != nil { - return fmt.Errorf("error modifying user peers in groups: %w", err) - } - - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil { - return fmt.Errorf("error saving groups: %w", err) + for _, peer := range peers { + for _, g := range addNewGroups { + if err := transaction.AddPeerToGroup(ctx, userAuth.AccountId, peer.ID, g); err != nil { + return fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, g, err) + } + } + for _, g := range removeOldGroups { + if err := transaction.RemovePeerFromGroup(ctx, peer.ID, g); err != nil { + return fmt.Errorf("error removing peer %s from group %s: %w", peer.ID, g, err) + } + } } if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { @@ -1971,53 +1965,56 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc // propagateUserGroupMemberships propagates all account users' group memberships to their peers. // Returns true if any groups were modified, true if those updates affect peers and an error. func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { - groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return false, false, err - } - - groupsMap := make(map[string]*types.Group, len(groups)) - for _, group := range groups { - groupsMap[group.ID] = group - } - users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, false, err } - groupsToUpdate := make(map[string]*types.Group) + accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, false, fmt.Errorf("error getting account group peers: %w", err) + } + accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, false, fmt.Errorf("error getting account groups: %w", err) + } + + for _, group := range accountGroups { + if _, exists := accountGroupPeers[group.ID]; !exists { + accountGroupPeers[group.ID] = make(map[string]struct{}) + } + } + + updatedGroups := []string{} for _, user := range users { userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) if err != nil { return false, false, err } - updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil) - if err != nil { - return false, false, err - } - - for _, group := range updatedGroups { - groupsToUpdate[group.ID] = group - groupsMap[group.ID] = group + for _, peer := range userPeers { + for _, groupID := range user.AutoGroups { + if _, exists := accountGroupPeers[groupID]; !exists { + // we do not wanna create the groups here + log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID) + continue + } + if _, exists := accountGroupPeers[groupID][peer.ID]; exists { + continue + } + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err) + } + updatedGroups = append(updatedGroups, groupID) + } } } - if len(groupsToUpdate) == 0 { - return false, false, nil - } - - peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate)) + peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups) if err != nil { - return false, false, err + return false, false, fmt.Errorf("error checking if group changes affect peers: %w", err) } - err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate)) - if err != nil { - return false, false, err - } - - return true, peersAffected, nil + return len(updatedGroups) > 0, peersAffected, nil } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 8c7e95e3d..0cd1c6637 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -62,8 +62,10 @@ type Manager interface { GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error + CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error + UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error + CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error + UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error diff --git a/management/server/account_test.go b/management/server/account_test.go index b65dffe6c..1dd74104b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1159,7 +1159,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Name: "GroupA", Peers: []string{}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1194,7 +1194,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }() group.Peers = []string{peer1.ID, peer2.ID, peer3.ID} - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1240,11 +1240,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID}, + AccountID: account.Id, + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1292,7 +1293,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, } - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { + if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1343,11 +1344,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) require.NoError(t, err, "failed to save group") @@ -1672,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) { }, Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, - Resources: []types.Resource{}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, }, }, Policies: []*types.Policy{ @@ -2616,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { } func TestAccount_SetJWTGroups(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", "postgres") manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -3360,7 +3363,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} - require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1)) + require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1)) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) require.NoError(t, err) @@ -3382,7 +3385,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { t.Run("should update membership and account peers for used groups", func(t *testing.T) { group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} - require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2)) + require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2)) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) require.NoError(t, err) diff --git a/management/server/dns_test.go b/management/server/dns_test.go index f2295450f..2af07d8e4 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -495,7 +495,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -506,7 +506,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { Name: "GroupB", Peers: []string{}, }, - }, true) + }) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -562,11 +562,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/group.go b/management/server/group.go index 130a67145..95bed7d18 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -65,22 +65,144 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } -// SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error { +// CreateGroup object of the peers +func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create) + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var eventsToStore []func() + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil { + return status.Errorf(status.Internal, "failed to create group: %v", err) + } + + for _, peerID := range newGroup.Peers { + if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil { + return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err) + } + } + return nil + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil } -// SaveGroups adds new groups to the account. +// UpdateGroup object of the peers +func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var eventsToStore []func() + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) + if err != nil { + return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) + } + + peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers) + peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers) + + for _, peerID := range peersToAdd { + if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil { + return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err) + } + } + for _, peerID := range peersToRemove { + if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil { + return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err) + } + } + + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// CreateGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error { - operation := operations.Create - if !create { - operation = operations.Update - } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation) +// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. +func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create) if err != nil { return status.NewPermissionValidationError(err) } @@ -116,7 +238,65 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) + return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// UpdateGroups updates groups in the account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that. +func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var eventsToStore []func() + var groupsToSave []*types.Group + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) }) if err != nil { return err @@ -265,20 +445,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *types.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) - if err != nil { - return err - } - - if updated := group.AddPeer(peerID); !updated { - return nil - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) if err != nil { return err @@ -288,7 +458,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID) }) if err != nil { return err @@ -329,7 +499,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err @@ -347,20 +517,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *types.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) - if err != nil { - return err - } - - if updated := group.RemovePeer(peerID); !updated { - return nil - } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) if err != nil { return err @@ -370,7 +530,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.RemovePeerFromGroup(ctx, peerID, groupID) }) if err != nil { return err @@ -411,7 +571,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun return err } - return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err diff --git a/management/server/group_test.go b/management/server/group_test.go index 631fe3a71..51069dc56 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -2,14 +2,20 @@ package server import ( "context" + "encoding/binary" "errors" "fmt" + "net" "net/netip" + "strconv" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/groups" @@ -18,8 +24,10 @@ import ( "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + peer2 "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -40,7 +48,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } for _, group := range account.Groups { group.Issued = types.GroupIssuedIntegration - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) + group.ID = uuid.New().String() + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } @@ -48,7 +57,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = types.GroupIssuedJWT - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) + group.ID = uuid.New().String() + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } @@ -56,7 +66,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = types.GroupIssuedAPI group.ID = "" - err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true) + err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { t.Errorf("should not create api group with the same name, %s", group.Name) } @@ -162,7 +172,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } - err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true) + err = manager.CreateGroups(context.Background(), account.Id, groupAdminUserID, groups) assert.NoError(t, err, "Failed to save test groups") testCases := []struct { @@ -382,13 +392,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t return nil, nil, err } - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) + _ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) acc, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { @@ -400,7 +410,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -426,8 +436,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Name: "GroupE", Peers: []string{peer2.ID}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err := manager.CreateGroup(context.Background(), account.Id, userID, group) + assert.NoError(t, err) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { @@ -442,11 +455,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -513,7 +526,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, Rules: []*types.PolicyRule{ { @@ -535,11 +548,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -604,11 +617,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -645,11 +658,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -672,11 +685,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -719,11 +732,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupE", Name: "GroupE", Peers: []string{peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) select { @@ -733,3 +746,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } }) } + +func Test_AddPeerToGroup(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + acc, err := createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + err = manager.Store.AddPeerToGroup(context.Background(), accountID, strconv.Itoa(i), acc.GroupsG[0].ID) + if err != nil { + errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + return + } + + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers)) +} + +func Test_AddPeerToAll(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i)) + if err != nil { + errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + return + } + + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers)) +} + +func Test_AddPeerAndAddToAll(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + peer := &peer2.Peer{ + ID: strconv.Itoa(i), + AccountID: accountID, + DNSLabel: "peer" + strconv.Itoa(i), + IP: uint32ToIP(uint32(i)), + } + + err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { + err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + if err != nil { + return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + } + err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID) + if err != nil { + return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) + } + return nil + }) + if err != nil { + t.Errorf("AddPeer failed for peer %d: %v", i, err) + return + } + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers)) + assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers)) +} + +func uint32ToIP(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip +} + +func Test_IncrementNetworkSerial(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + accountID := "testaccount" + userID := "testuser" + + _, err = createAccount(manager, accountID, userID, "domain.com") + if err != nil { + t.Fatal("error creating account") + return + } + + const totalPeers = 1000 + + var wg sync.WaitGroup + errs := make(chan error, totalPeers) + start := make(chan struct{}) + for i := 0; i < totalPeers; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + <-start + + err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { + err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID) + if err != nil { + return fmt.Errorf("failed to get account %s: %v", accountID, err) + } + return nil + }) + if err != nil { + t.Errorf("AddPeer failed for peer %d: %v", i, err) + return + } + }(i) + } + startTime := time.Now() + + close(start) + wg.Wait() + close(errs) + + t.Logf("time since start: %s", time.Since(startTime)) + + for err := range errs { + t.Fatal(err) + } + + account, err := manager.Store.GetAccount(context.Background(), accountID) + if err != nil { + t.Fatalf("Failed to get account %s: %v", accountID, err) + } + + assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial) +} diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 3ae833dc0..bede652f5 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil { + if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil { log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) util.WriteError(r.Context(), err, w) return @@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { Issued: types.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true) + err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index c2f1a5abf..88af9a58f 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -39,6 +39,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f return nil } + if !db.Migrator().HasColumn(&model, fieldName) { + log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName) + return nil + } + stmt := &gorm.Statement{DB: db} err := stmt.Parse(model) if err != nil { @@ -422,3 +427,62 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName) return nil } + +func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error { + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if !db.Migrator().HasColumn(&model, columnName) { + log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName) + return nil + } + + if err := db.Transaction(func(tx *gorm.DB) error { + var rows []map[string]any + if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil { + return fmt.Errorf("find rows: %w", err) + } + + for _, row := range rows { + jsonValue, ok := row[columnName].(string) + if !ok || jsonValue == "" { + continue + } + + var data []string + if err := json.Unmarshal([]byte(jsonValue), &data); err != nil { + return fmt.Errorf("unmarshal json: %w", err) + } + + for _, value := range data { + if err := tx.Create( + mapperFunc(row["account_id"].(string), row["id"].(string), value), + ).Error; err != nil { + return fmt.Errorf("failed to insert id %v: %w", row["id"], err) + } + } + } + + if err := tx.Migrator().DropColumn(&model, columnName); err != nil { + return fmt.Errorf("drop column %s: %w", columnName, err) + } + + return nil + }); err != nil { + return err + } + + log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName) + return nil +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a16e3652c..8c8fd19c9 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -124,6 +124,34 @@ type MockAccountManager struct { BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } +func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { + if am.SaveGroupFunc != nil { + return am.SaveGroupFunc(ctx, accountID, userID, group, true) + } + return status.Errorf(codes.Unimplemented, "method CreateGroup is not implemented") +} + +func (am *MockAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { + if am.SaveGroupFunc != nil { + return am.SaveGroupFunc(ctx, accountID, userID, group, false) + } + return status.Errorf(codes.Unimplemented, "method UpdateGroup is not implemented") +} + +func (am *MockAccountManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error { + if am.SaveGroupsFunc != nil { + return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, true) + } + return status.Errorf(codes.Unimplemented, "method CreateGroups is not implemented") +} + +func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error { + if am.SaveGroupsFunc != nil { + return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, false) + } + return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented") +} + func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { if am.UpdateAccountPeersFunc != nil { am.UpdateAccountPeersFunc(ctx, accountID) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 25eb03b83..959e7856a 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -980,18 +980,18 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ - { - ID: "groupA", - Name: "GroupA", - Peers: []string{}, - }, - { - ID: "groupB", - Name: "GroupB", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, - }, true) + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }) + assert.NoError(t, err) + + err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) assert.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/peer.go b/management/server/peer.go index 3c40c6bb6..f954369d8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -374,12 +374,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { - return err + if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { + return fmt.Errorf("failed to remove peer from groups: %w", err) } eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) - return err + if err != nil { + return fmt.Errorf("failed to delete peer: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil }) if err != nil { return err @@ -478,7 +486,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer - var updateAccountPeers bool var setupKeyID string var setupKeyName string @@ -615,20 +622,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return err } - err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) - if err != nil { - return fmt.Errorf("failed adding peer to All group: %w", err) - } - if len(groupsToAdd) > 0 { for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g) + err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g) if err != nil { return err } } } + err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -678,7 +685,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) + updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) if err != nil { updateAccountPeers = true } @@ -1021,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is }() if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, nil, nil, err } @@ -1523,17 +1530,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { - groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) - if err != nil { - return nil, err - } - - groupIDs := make([]string, 0, len(groups)) - for _, group := range groups { - groupIDs = append(groupIDs, group.ID) - } - - return groupIDs, err + return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID) } // IsPeerInActiveGroup checks if the given peer is part of a group that is used @@ -1563,17 +1560,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto } for _, peer := range peers { - groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peer.ID) - if err != nil { - return nil, fmt.Errorf("failed to get peer groups: %w", err) - } - - for _, group := range groups { - group.RemovePeer(peer.ID) - err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) - if err != nil { - return nil, fmt.Errorf("failed to save group: %w", err) - } + if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { + return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID) } if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4f6ae500e..947e53a60 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -310,12 +310,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) - err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true) + err = manager.CreateGroup(context.Background(), account.Id, userID, &group1) if err != nil { t.Errorf("expecting group1 to be added, got failure %v", err) return } - err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true) + err = manager.CreateGroup(context.Background(), account.Id, userID, &group2) if err != nil { t.Errorf("expecting group2 to be added, got failure %v", err) return @@ -1475,6 +1475,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { } func Test_RegisterPeerRollbackOnFailure(t *testing.T) { + engine := os.Getenv("NETBIRD_STORE_ENGINE") + if engine == "sqlite" || engine == "" { + t.Skip("Skipping test because sqlite test store is not respecting foreign keys") + } if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1709,7 +1713,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1725,8 +1729,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }, true) - require.NoError(t, err) + } + for _, group := range g { + err = manager.CreateGroup(context.Background(), account.Id, userID, group) + require.NoError(t, err) + } // create a user with auto groups _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{ @@ -1785,7 +1792,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg) // close(done) }() @@ -2164,7 +2171,6 @@ func Test_IsUniqueConstraintError(t *testing.T) { } func Test_AddPeer(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine)) manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -2176,7 +2182,7 @@ func Test_AddPeer(t *testing.T) { _, err = createAccount(manager, accountID, userID, "domain.com") if err != nil { - t.Fatal("error creating account") + t.Fatalf("error creating account: %v", err) return } @@ -2186,22 +2192,21 @@ func Test_AddPeer(t *testing.T) { return } - const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries) - const differentHostnames = 50 + const totalPeers = 300 var wg sync.WaitGroup - errs := make(chan error, totalPeers+differentHostnames) + errs := make(chan error, totalPeers) start := make(chan struct{}) for i := 0; i < totalPeers; i++ { wg.Add(1) - hostNameID := i % differentHostnames go func(i int) { defer wg.Done() newPeer := &nbpeer.Peer{ - Key: "key" + strconv.Itoa(i), - Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"}, + AccountID: accountID, + Key: "key" + strconv.Itoa(i), + Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"}, } <-start diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 4352f3cff..4a08f4c33 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -993,7 +993,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1014,8 +1014,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Name: "GroupD", Peers: []string{peer1.ID, peer2.ID}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err := manager.CreateGroup(context.Background(), account.Id, userID, group) + assert.NoError(t, err) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { @@ -1025,6 +1028,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { var policyWithGroupRulesNoPeers *types.Policy var policyWithDestinationPeersOnly *types.Policy var policyWithSourceAndDestinationPeers *types.Policy + var err error // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index f93467375..67760d55a 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/posture" @@ -105,10 +105,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er Id: regularUserID, Role: types.UserRoleUser, } + peer1 := &peer.Peer{ + ID: "peer1", + } account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false) account.Users[admin.Id] = admin account.Users[user.Id] = user + account.Peers["peer1"] = peer1 err := am.Store.SaveAccount(context.Background(), account) if err != nil { @@ -121,7 +125,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -137,8 +141,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err := manager.CreateGroup(context.Background(), account.Id, userID, group) + assert.NoError(t, err) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { @@ -156,7 +163,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true) + postureCheckA, err := manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true) require.NoError(t, err) postureCheckB := &posture.Checks{ @@ -449,14 +456,16 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { AccountID: account.Id, Peers: []string{"peer1"}, } + err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupA) + require.NoError(t, err, "failed to create groupA") groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB}) - require.NoError(t, err, "failed to save groups") + err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupB) + require.NoError(t, err, "failed to create groupB") postureCheckA := &posture.Checks{ Name: "checkA", @@ -535,7 +544,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { groupA.Peers = []string{} - err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA) + err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA) require.NoError(t, err, "failed to save groups") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route_test.go b/management/server/route_test.go index 37c37f624..ffd1a284b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true) + err = am.CreateGroup(context.Background(), account.Id, userID, newGroup) require.NoError(t, err) rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") @@ -1505,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou } for _, group := range newGroup { - err = am.SaveGroup(context.Background(), accountID, userID, group, true) + err = am.CreateGroup(context.Background(), accountID, userID, group) if err != nil { return nil, err } @@ -1953,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + g := []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1969,8 +1969,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Name: "GroupC", Peers: []string{}, }, - }, true) - assert.NoError(t, err) + } + for _, group := range g { + err = manager.CreateGroup(context.Background(), account.Id, userID, group) + require.NoError(t, err, "failed to create group %s", group.Name) + } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { @@ -2149,11 +2152,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, - }, true) + }) assert.NoError(t, err) select { @@ -2189,11 +2192,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, - }, true) + }) assert.NoError(t, err) select { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index cecf55200..e55b33c94 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -29,7 +29,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ + err = manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -40,7 +40,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { Name: "group_name_2", Peers: []string{}, }, - }, true) + }) if err != nil { t.Fatal(err) } @@ -104,20 +104,20 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, - }, true) + }) if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, - }, true) + }) if err != nil { t.Fatal(err) } @@ -398,11 +398,11 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) assert.NoError(t, err) policy := &types.Policy{ diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index e380a7da7..c2f0dff6d 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( - &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, @@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro generateAccountSQLTypes(account) + for _, group := range account.GroupsG { + group.StoreGroupPeers() + } + err := s.db.Transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { @@ -247,7 +251,8 @@ func generateAccountSQLTypes(account *types.Account) { for id, group := range account.Groups { group.ID = id - account.GroupsG = append(account.GroupsG, *group) + group.AccountID = account.Id + account.GroupsG = append(account.GroupsG, group) } for id, route := range account.Routes { @@ -449,25 +454,56 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u return nil } -// SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { +// CreateGroups creates the given list of groups to the database. +func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { if len(groups) == 0 { return nil } - result := s.db. - Clauses( - clause.Locking{Strength: string(lockStrength)}, - clause.OnConflict{ - Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, - UpdateAll: true, - }, - ). - Create(&groups) - if result.Error != nil { - return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) + return s.db.Transaction(func(tx *gorm.DB) error { + result := tx. + Clauses( + clause.Locking{Strength: string(lockStrength)}, + clause.OnConflict{ + Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, + UpdateAll: true, + }, + ). + Omit(clause.Associations). + Create(&groups) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save groups to store") + } + + return nil + }) +} + +// UpdateGroups updates the given list of groups to the database. +func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { + if len(groups) == 0 { + return nil } - return nil + + return s.db.Transaction(func(tx *gorm.DB) error { + result := tx. + Clauses( + clause.Locking{Strength: string(lockStrength)}, + clause.OnConflict{ + Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, + UpdateAll: true, + }, + ). + Omit(clause.Associations). + Create(&groups) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save groups to store") + } + + return nil + }) } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -646,7 +682,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr } var groups []*types.Group - result := tx.Find(&groups, accountIDCondition, accountID) + result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -655,6 +691,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return nil, status.Errorf(status.Internal, "failed to get account groups from the store") } + for _, g := range groups { + g.LoadGroupPeers() + } + return groups, nil } @@ -669,6 +709,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt likePattern := `%"ID":"` + resourceID + `"%` result := tx. + Preload(clause.Associations). Where("resources LIKE ?", likePattern). Find(&groups) @@ -679,6 +720,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt return nil, result.Error } + for _, g := range groups { + g.LoadGroupPeers() + } + return groups, nil } @@ -765,6 +810,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). + Omit("GroupsG"). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). First(&account, idQueryCondition, accountID) @@ -814,6 +860,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } account.GroupsG = nil + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) for _, route := range account.RoutesG { account.Routes[route.ID] = route.Copy() @@ -1311,55 +1368,76 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction -func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { - var group types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&group, "account_id = ? AND name = ?", accountID, "All") - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group 'All' not found for account") - } - return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + var groupID string + _ = s.db.Model(types.Group{}). + Select("id"). + Where("account_id = ? AND name = ?", accountID, "All"). + Limit(1). + Scan(&groupID) + + if groupID == "" { + return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID) } - for _, existingPeerID := range group.Peers { - if existingPeerID == peerID { - return nil - } - } + err := s.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, + DoNothing: true, + }).Create(&types.GroupPeer{ + AccountID: accountID, + GroupID: groupID, + PeerID: peerID, + }).Error - group.Peers = append(group.Peers, peerID) - - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All': %s", err) + if err != nil { + return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err) } return nil } -// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction -func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error { - var group types.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID). - First(&group) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.NewGroupNotFoundError(groupID) - } - - return status.Errorf(status.Internal, "issue finding group: %s", result.Error) +// AddPeerToGroup adds a peer to a group +func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { + peer := &types.GroupPeer{ + AccountID: accountID, + GroupID: groupID, + PeerID: peerID, } - for _, existingPeerID := range group.Peers { - if existingPeerID == peerId { - return nil - } + err := s.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, + DoNothing: true, + }).Create(peer).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err) + return status.Errorf(status.Internal, "failed to add peer to group") } - group.Peers = append(group.Peers, peerId) + return nil +} - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group: %s", err) +// RemovePeerFromGroup removes a peer from a group +func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error { + err := s.db.WithContext(ctx). + Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err) + return status.Errorf(status.Internal, "failed to remove peer from group") + } + + return nil +} + +// RemovePeerFromAllGroups removes a peer from all groups +func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { + err := s.db.WithContext(ctx). + Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err) + return status.Errorf(status.Internal, "failed to remove peer from all groups") } return nil @@ -1427,15 +1505,46 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng var groups []*types.Group query := tx. - Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) + Joins("JOIN group_peers ON group_peers.group_id = groups.id"). + Where("group_peers.peer_id = ?", peerId). + Preload(clause.Associations). + Find(&groups) if query.Error != nil { return nil, query.Error } + for _, group := range groups { + group.LoadGroupPeers() + } + return groups, nil } +// GetPeerGroupIDs retrieves all group IDs assigned to a specific peer in a given account. +func (s *SqlStore) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var groupIDs []string + query := tx. + Model(&types.GroupPeer{}). + Where("account_id = ? AND peer_id = ?", accountId, peerId). + Pluck("group_id", &groupIDs) + + if query.Error != nil { + if errors.Is(query.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no groups found for peer %s in account %s", peerId, accountId) + } + log.WithContext(ctx).Errorf("failed to get group IDs for peer %s in account %s: %v", peerId, accountId, query.Error) + return nil, status.Errorf(status.Internal, "failed to get group IDs for peer from store") + } + + return groupIDs, nil +} + // GetAccountPeers retrieves peers for an account. func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer @@ -1485,7 +1594,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil { + if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1722,7 +1831,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } var group *types.Group - result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID) + result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewGroupNotFoundError(groupID) @@ -1731,15 +1840,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt return nil, status.Errorf(status.Internal, "failed to get group from store") } + group.LoadGroupPeers() + return group, nil } // GetGroupByName retrieves a group by name and account ID. func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { tx := s.db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } var group types.Group @@ -1747,16 +1855,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // we may need to reconsider changing the types. query := tx.Preload(clause.Associations) - switch s.storeEngine { - case types.PostgresStoreEngine: - query = query.Order("json_array_length(peers::json) DESC") - case types.MysqlStoreEngine: - query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") - default: - query = query.Order("json_array_length(peers) DESC") - } - - result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) + result := query. + Model(&types.Group{}). + Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id"). + Where("groups.account_id = ? AND groups.name = ?", accountID, groupName). + Group("groups.id"). + Order("COUNT(group_peers.peer_id) DESC"). + Limit(1). + First(&group) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewGroupNotFoundError(groupName) @@ -1764,6 +1870,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) return nil, status.Errorf(status.Internal, "failed to get group by name from store") } + + group.LoadGroupPeers() + return &group, nil } @@ -1775,7 +1884,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } var groups []*types.Group - result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) + result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") @@ -1783,25 +1892,45 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren groupsMap := make(map[string]*types.Group) for _, group := range groups { + group.LoadGroupPeers() groupsMap[group.ID] = group } return groupsMap, nil } -// SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) - if result.Error != nil { - log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) +// CreateGroup creates a group in the store. +func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { + if group == nil { + return status.Errorf(status.InvalidArgument, "group is nil") + } + + if err := s.db.Omit(clause.Associations).Create(group).Error; err != nil { + log.WithContext(ctx).Errorf("failed to save group to store: %v", err) return status.Errorf(status.Internal, "failed to save group to store") } + + return nil +} + +// UpdateGroup updates a group in the store. +func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { + if group == nil { + return status.Errorf(status.InvalidArgument, "group is nil") + } + + if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil { + log.WithContext(ctx).Errorf("failed to save group to store: %v", err) + return status.Errorf(status.Internal, "failed to save group to store") + } + return nil } // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Select(clause.Associations). Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) @@ -1818,6 +1947,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Select(clause.Associations). Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) @@ -2613,3 +2743,27 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri return count, nil } + +func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var peers []types.GroupPeer + result := tx.Find(&peers, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get account group peers from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account group peers from store") + } + + groupPeers := make(map[string]map[string]struct{}) + for _, peer := range peers { + if _, exists := groupPeers[peer.GroupID]; !exists { + groupPeers[peer.GroupID] = make(map[string]struct{}) + } + groupPeers[peer.GroupID][peer.PeerID] = struct{}{} + } + + return groupPeers, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 738c5a28c..44bb3f599 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "encoding/binary" "fmt" "math/rand" "net" @@ -1187,7 +1188,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { Peers: nil, } err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { - err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + err := transaction.CreateGroup(context.Background(), LockingStrengthUpdate, group) if err != nil { t.Fatal("failed to save group") return err @@ -1348,7 +1349,8 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { } } -func TestSqlStore_SaveGroup(t *testing.T) { +func TestSqlStore_CreateGroup(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(types.MysqlStoreEngine)) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1356,12 +1358,14 @@ func TestSqlStore_SaveGroup(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" group := &types.Group{ - ID: "group-id", - AccountID: accountID, - Issued: "api", - Peers: []string{"peer1", "peer2"}, + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, } - err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group) require.NoError(t, err) savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") @@ -1369,7 +1373,7 @@ func TestSqlStore_SaveGroup(t *testing.T) { require.Equal(t, savedGroup, group) } -func TestSqlStore_SaveGroups(t *testing.T) { +func TestSqlStore_CreateUpdateGroups(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1378,23 +1382,27 @@ func TestSqlStore_SaveGroups(t *testing.T) { groups := []*types.Group{ { - ID: "group-1", - AccountID: accountID, - Issued: "api", - Peers: []string{"peer1", "peer2"}, + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, }, { - ID: "group-2", - AccountID: accountID, - Issued: "integration", - Peers: []string{"peer3", "peer4"}, + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{}, + Resources: []types.Resource{}, + GroupPeers: []types.GroupPeer{}, }, } - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups) + err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groups) require.NoError(t, err) groups[1].Peers = []string{} - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups) + err = store.UpdateGroups(context.Background(), LockingStrengthUpdate, accountID, groups) require.NoError(t, err) group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID) @@ -2523,7 +2531,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) { require.NoError(t, err, "failed to get group") require.Len(t, group.Peers, 0, "group should have 0 peers") - err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID) + err = store.AddPeerToGroup(context.Background(), accountID, peerID, groupID) require.NoError(t, err, "failed to add peer to group") group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) @@ -2554,7 +2562,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) { err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) require.NoError(t, err, "failed to add peer to account") - err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID) + err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID) require.NoError(t, err, "failed to add peer to all group") group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) @@ -2640,7 +2648,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) { assert.Len(t, groups, 1) assert.Equal(t, groups[0].Name, "All") - err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h") + err = store.AddPeerToGroup(context.Background(), accountID, peerID, "cfefqs706sqkneg59g4h") require.NoError(t, err) groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) @@ -3307,7 +3315,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { }) } - err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave) + err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave) require.NoError(t, err) accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) @@ -3538,3 +3546,64 @@ func TestSqlStore_GetAnyAccountID(t *testing.T) { assert.Empty(t, accountID) }) } + +func BenchmarkGetAccountPeers(b *testing.B) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", b.TempDir()) + if err != nil { + b.Fatal(err) + } + b.Cleanup(cleanup) + + numberOfPeers := 1000 + numberOfGroups := 200 + numberOfPeersPerGroup := 500 + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peers := make([]*nbpeer.Peer, 0, numberOfPeers) + for i := 0; i < numberOfPeers; i++ { + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + DNSLabel: fmt.Sprintf("peer%d.example.com", i), + IP: intToIPv4(uint32(i)), + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + if err != nil { + b.Fatalf("Failed to add peer: %v", err) + } + peers = append(peers, peer) + } + + for i := 0; i < numberOfGroups; i++ { + groupID := fmt.Sprintf("group-%d", i) + group := &types.Group{ + ID: groupID, + AccountID: accountID, + } + err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + b.Fatalf("Failed to create group: %v", err) + } + for j := 0; j < numberOfPeersPerGroup; j++ { + peerIndex := (i*numberOfPeersPerGroup + j) % numberOfPeers + err = store.AddPeerToGroup(context.Background(), accountID, peers[peerIndex].ID, groupID) + if err != nil { + b.Fatalf("Failed to add peer to group: %v", err) + } + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peers[i%numberOfPeers].ID) + if err != nil { + b.Fatal(err) + } + } +} + +func intToIPv4(n uint32) net.IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, n) + return ip +} diff --git a/management/server/store/store.go b/management/server/store/store.go index b3254c4c9..912939bc2 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -101,8 +101,10 @@ type Store interface { GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error + CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error + UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error + CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error + UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error @@ -120,9 +122,12 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) - AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error - AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error + RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error + RemovePeerFromAllGroups(ctx context.Context, peerID string) error GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) + GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error @@ -196,6 +201,7 @@ type Store interface { DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) + GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) } const ( @@ -353,6 +359,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label") }, + func(db *gorm.DB) error { + return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(accountID, id, value string) any { + return &types.GroupPeer{ + AccountID: accountID, + GroupID: id, + PeerID: value, + } + }) + }, } } diff --git a/management/server/types/account.go b/management/server/types/account.go index f0887be07..a3a7ce305 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -73,7 +73,7 @@ type Account struct { Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` diff --git a/management/server/types/group.go b/management/server/types/group.go index 1b321387c..00fdf7a69 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -26,7 +26,8 @@ type Group struct { Issued string // Peers list of the group - Peers []string `gorm:"serializer:json"` + Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership + GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"` // Resources contains a list of resources in that group Resources []Resource `gorm:"serializer:json"` @@ -34,6 +35,32 @@ type Group struct { IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } +type GroupPeer struct { + AccountID string `gorm:"index"` + GroupID string `gorm:"primaryKey"` + PeerID string `gorm:"primaryKey"` +} + +func (g *Group) LoadGroupPeers() { + g.Peers = make([]string, len(g.GroupPeers)) + for i, peer := range g.GroupPeers { + g.Peers[i] = peer.PeerID + } + g.GroupPeers = []GroupPeer{} +} + +func (g *Group) StoreGroupPeers() { + g.GroupPeers = make([]GroupPeer, len(g.Peers)) + for i, peer := range g.Peers { + g.GroupPeers[i] = GroupPeer{ + AccountID: g.AccountID, + GroupID: g.ID, + PeerID: peer, + } + } + g.Peers = []string{} +} + // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -46,13 +73,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an func (g *Group) Copy() *Group { group := &Group{ ID: g.ID, + AccountID: g.AccountID, Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + GroupPeers: make([]GroupPeer, len(g.GroupPeers)), Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.GroupPeers, g.GroupPeers) copy(group.Resources, g.Resources) return group } diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go index 69b381ae5..3d421342d 100644 --- a/management/server/types/setupkey.go +++ b/management/server/types/setupkey.go @@ -35,7 +35,7 @@ type SetupKey struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` Key string - KeySecret string + KeySecret string `gorm:"index"` Name string Type SetupKeyType CreatedAt time.Time diff --git a/management/server/user.go b/management/server/user.go index 7d8382978..a0f4c4a6c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -677,13 +677,18 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact if update.AutoGroups != nil && settings.GroupsPropagationEnabled { removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) - if err != nil { - return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) - } - - if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, updatedGroups); err != nil { - return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) + addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups) + for _, peer := range userPeers { + for _, groupID := range removedGroups { + if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err) + } + } + for _, groupID := range addedGroups { + if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { + return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err) + } + } } } @@ -1137,93 +1142,6 @@ func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID str return userInfo, nil } -// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { - if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return - } - - userPeerIDMap := make(map[string]struct{}, len(peers)) - for _, peer := range peers { - userPeerIDMap[peer.ID] = struct{}{} - } - - for _, gid := range groupsToAdd { - group, ok := accountGroups[gid] - if !ok { - return nil, errors.New("group not found") - } - if changed := addUserPeersToGroup(userPeerIDMap, group); changed { - groupsToUpdate = append(groupsToUpdate, group) - } - } - - for _, gid := range groupsToRemove { - group, ok := accountGroups[gid] - if !ok { - return nil, errors.New("group not found") - } - if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed { - groupsToUpdate = append(groupsToUpdate, group) - } - } - - return groupsToUpdate, nil -} - -// addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool { - groupPeers := make(map[string]struct{}, len(group.Peers)) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - changed := false - for pid := range userPeerIDs { - if _, exists := groupPeers[pid]; !exists { - groupPeers[pid] = struct{}{} - changed = true - } - } - - group.Peers = make([]string, 0, len(groupPeers)) - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - if changed { - group.Peers = make([]string, 0, len(groupPeers)) - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - } - return changed -} - -// removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool { - // skip removing peers from group All - if group.Name == "All" { - return false - } - - updatedPeers := make([]string, 0, len(group.Peers)) - changed := false - - for _, pid := range group.Peers { - if _, owned := userPeerIDs[pid]; owned { - changed = true - continue - } - updatedPeers = append(updatedPeers, pid) - } - - if changed { - group.Peers = updatedPeers - } - return changed -} - func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index 53baf8f7e..8ab6584cf 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1335,11 +1335,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ + err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }, true) + }) require.NoError(t, err) policy := &types.Policy{