mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[management] migrate group peers into seperate table (#4096)
This commit is contained in:
@@ -1368,7 +1368,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -1382,28 +1382,22 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*types.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
|
||||
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
for _, peer := range peers {
|
||||
for _, g := range addNewGroups {
|
||||
if err := transaction.AddPeerToGroup(ctx, userAuth.AccountId, peer.ID, g); err != nil {
|
||||
return fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, g, err)
|
||||
}
|
||||
}
|
||||
for _, g := range removeOldGroups {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, g); err != nil {
|
||||
return fmt.Errorf("error removing peer %s from group %s: %w", peer.ID, g, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
|
||||
@@ -1971,53 +1965,56 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||
// Returns true if any groups were modified, true if those updates affect peers and an error.
|
||||
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*types.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
groupsToUpdate := make(map[string]*types.Group)
|
||||
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, false, fmt.Errorf("error getting account group peers: %w", err)
|
||||
}
|
||||
|
||||
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, false, fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
for _, group := range accountGroups {
|
||||
if _, exists := accountGroupPeers[group.ID]; !exists {
|
||||
accountGroupPeers[group.ID] = make(map[string]struct{})
|
||||
}
|
||||
}
|
||||
|
||||
updatedGroups := []string{}
|
||||
for _, user := range users {
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
for _, group := range updatedGroups {
|
||||
groupsToUpdate[group.ID] = group
|
||||
groupsMap[group.ID] = group
|
||||
for _, peer := range userPeers {
|
||||
for _, groupID := range user.AutoGroups {
|
||||
if _, exists := accountGroupPeers[groupID]; !exists {
|
||||
// we do not wanna create the groups here
|
||||
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
|
||||
continue
|
||||
}
|
||||
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
|
||||
continue
|
||||
}
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
|
||||
return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
|
||||
}
|
||||
updatedGroups = append(updatedGroups, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupsToUpdate) == 0 {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate))
|
||||
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
return false, false, fmt.Errorf("error checking if group changes affect peers: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate))
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
}
|
||||
|
||||
return true, peersAffected, nil
|
||||
return len(updatedGroups) > 0, peersAffected, nil
|
||||
}
|
||||
|
||||
@@ -62,8 +62,10 @@ type Manager interface {
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
|
||||
@@ -1159,7 +1159,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1194,7 +1194,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}()
|
||||
|
||||
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1240,11 +1240,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
AccountID: account.Id,
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1292,7 +1293,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1343,11 +1344,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
@@ -1672,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) {
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -2616,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -3360,7 +3363,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
|
||||
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
|
||||
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
|
||||
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
require.NoError(t, err)
|
||||
@@ -3382,7 +3385,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
|
||||
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
|
||||
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
|
||||
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -495,7 +495,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -506,7 +506,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -562,11 +562,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// Creating DNS settings with groups that have peers should update account peers and send peer update
|
||||
t.Run("creating dns setting with used groups", func(t *testing.T) {
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
@@ -65,22 +65,144 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error {
|
||||
// CreateGroup object of the peers
|
||||
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create)
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveGroups adds new groups to the account.
|
||||
// UpdateGroup object of the peers
|
||||
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||
}
|
||||
|
||||
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error {
|
||||
operation := operations.Create
|
||||
if !create {
|
||||
operation = operations.Update
|
||||
}
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation)
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
@@ -116,7 +238,65 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
|
||||
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGroups updates groups in the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
|
||||
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*types.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -265,20 +445,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -288,7 +458,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -329,7 +499,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -347,20 +517,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -370,7 +530,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -411,7 +571,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,14 +2,20 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -18,8 +24,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -40,7 +48,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
}
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedIntegration
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
|
||||
group.ID = uuid.New().String()
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration)
|
||||
}
|
||||
@@ -48,7 +57,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedJWT
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
|
||||
group.ID = uuid.New().String()
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", types.GroupIssuedJWT)
|
||||
}
|
||||
@@ -56,7 +66,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedAPI
|
||||
group.ID = ""
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err == nil {
|
||||
t.Errorf("should not create api group with the same name, %s", group.Name)
|
||||
}
|
||||
@@ -162,7 +172,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true)
|
||||
err = manager.CreateGroups(context.Background(), account.Id, groupAdminUserID, groups)
|
||||
assert.NoError(t, err, "Failed to save test groups")
|
||||
|
||||
testCases := []struct {
|
||||
@@ -382,13 +392,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
@@ -400,7 +410,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -426,8 +436,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupE",
|
||||
Peers: []string{peer2.ID},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -442,11 +455,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -513,7 +526,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
// adding a group to policy
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
@@ -535,11 +548,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -604,11 +617,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -645,11 +658,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -672,11 +685,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -719,11 +732,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupE",
|
||||
Name: "GroupE",
|
||||
Peers: []string{peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -733,3 +746,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AddPeerToGroup(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
acc, err := createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToGroup(context.Background(), accountID, strconv.Itoa(i), acc.GroupsG[0].ID)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
peer := &peer2.Peer{
|
||||
ID: strconv.Itoa(i),
|
||||
AccountID: accountID,
|
||||
DNSLabel: "peer" + strconv.Itoa(i),
|
||||
IP: uint32ToIP(uint32(i)),
|
||||
}
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func uint32ToIP(n uint32) net.IP {
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func Test_IncrementNetworkSerial(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
IntegrationReference: existingGroup.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil {
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: types.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true)
|
||||
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -39,6 +39,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasColumn(&model, fieldName) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
@@ -422,3 +427,62 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
|
||||
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error {
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if !db.Migrator().HasColumn(&model, columnName) {
|
||||
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
jsonValue, ok := row[columnName].(string)
|
||||
if !ok || jsonValue == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var data []string
|
||||
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
|
||||
return fmt.Errorf("unmarshal json: %w", err)
|
||||
}
|
||||
|
||||
for _, value := range data {
|
||||
if err := tx.Create(
|
||||
mapperFunc(row["account_id"].(string), row["id"].(string), value),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
|
||||
return fmt.Errorf("drop column %s: %w", columnName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -124,6 +124,34 @@ type MockAccountManager struct {
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
if am.SaveGroupFunc != nil {
|
||||
return am.SaveGroupFunc(ctx, accountID, userID, group, true)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CreateGroup is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
if am.SaveGroupFunc != nil {
|
||||
return am.SaveGroupFunc(ctx, accountID, userID, group, false)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateGroup is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
|
||||
if am.SaveGroupsFunc != nil {
|
||||
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, true)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CreateGroups is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
|
||||
if am.SaveGroupsFunc != nil {
|
||||
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, false)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
if am.UpdateAccountPeersFunc != nil {
|
||||
am.UpdateAccountPeersFunc(ctx, accountID)
|
||||
|
||||
@@ -980,18 +980,18 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
var newNameServerGroupA *nbdns.NameServerGroup
|
||||
var newNameServerGroupB *nbdns.NameServerGroup
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
}, true)
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
|
||||
@@ -374,12 +374,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
|
||||
return fmt.Errorf("failed to remove peer from groups: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -478,7 +486,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var updateAccountPeers bool
|
||||
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
@@ -615,20 +622,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
|
||||
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
@@ -678,7 +685,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
@@ -1021,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -1523,17 +1530,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
|
||||
|
||||
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
||||
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
|
||||
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
groupIDs = append(groupIDs, group.ID)
|
||||
}
|
||||
|
||||
return groupIDs, err
|
||||
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
@@ -1563,17 +1560,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get peer groups: %w", err)
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
group.RemovePeer(peer.ID)
|
||||
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save group: %w", err)
|
||||
}
|
||||
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
|
||||
return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID)
|
||||
}
|
||||
|
||||
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil {
|
||||
|
||||
@@ -310,12 +310,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
group1.Peers = append(group1.Peers, peer1.ID)
|
||||
group2.Peers = append(group2.Peers, peer2.ID)
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true)
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &group1)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group1 to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true)
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &group2)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group2 to be added, got failure %v", err)
|
||||
return
|
||||
@@ -1475,6 +1475,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine == "sqlite" || engine == "" {
|
||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1709,7 +1713,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -1725,8 +1729,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// create a user with auto groups
|
||||
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{
|
||||
@@ -1785,7 +1792,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2164,7 +2171,6 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_AddPeer(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -2176,7 +2182,7 @@ func Test_AddPeer(t *testing.T) {
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
t.Fatalf("error creating account: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2186,22 +2192,21 @@ func Test_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
|
||||
const differentHostnames = 50
|
||||
const totalPeers = 300
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers+differentHostnames)
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
hostNameID := i % differentHostnames
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
|
||||
AccountID: accountID,
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
|
||||
}
|
||||
|
||||
<-start
|
||||
|
||||
@@ -993,7 +993,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
|
||||
func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -1014,8 +1014,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -1025,6 +1028,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
var policyWithGroupRulesNoPeers *types.Policy
|
||||
var policyWithDestinationPeersOnly *types.Policy
|
||||
var policyWithSourceAndDestinationPeers *types.Policy
|
||||
var err error
|
||||
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -105,10 +105,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
Id: regularUserID,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
peer1 := &peer.Peer{
|
||||
ID: "peer1",
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
account.Peers["peer1"] = peer1
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -121,7 +125,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -137,8 +141,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -156,7 +163,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
|
||||
postureCheckA, err := manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
@@ -449,14 +456,16 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
AccountID: account.Id,
|
||||
Peers: []string{"peer1"},
|
||||
}
|
||||
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupA)
|
||||
require.NoError(t, err, "failed to create groupA")
|
||||
|
||||
groupB := &types.Group{
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Peers: []string{},
|
||||
}
|
||||
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB})
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupB)
|
||||
require.NoError(t, err, "failed to create groupB")
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
Name: "checkA",
|
||||
@@ -535,7 +544,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
groupA.Peers = []string{}
|
||||
err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA)
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||
|
||||
@@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
Name: "peer1 group",
|
||||
Peers: []string{peer1ID},
|
||||
}
|
||||
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true)
|
||||
err = am.CreateGroup(context.Background(), account.Id, userID, newGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
|
||||
@@ -1505,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
}
|
||||
|
||||
for _, group := range newGroup {
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, group, true)
|
||||
err = am.CreateGroup(context.Background(), accountID, userID, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1953,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
account, err := initTestRouteAccount(t, manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
g := []*types.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -1969,8 +1969,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
for _, group := range g {
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
|
||||
require.NoError(t, err, "failed to create group %s", group.Name)
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
|
||||
t.Cleanup(func() {
|
||||
@@ -2149,11 +2152,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -2189,11 +2192,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
err = manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
|
||||
{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
@@ -40,7 +40,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
},
|
||||
}, true)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -104,20 +104,20 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
}, true)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
}, true)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -398,11 +398,11 @@ func TestSetupKey_Copy(t *testing.T) {
|
||||
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := &types.Policy{
|
||||
|
||||
@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
@@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
generateAccountSQLTypes(account)
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
}
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||
if result.Error != nil {
|
||||
@@ -247,7 +251,8 @@ func generateAccountSQLTypes(account *types.Account) {
|
||||
|
||||
for id, group := range account.Groups {
|
||||
group.ID = id
|
||||
account.GroupsG = append(account.GroupsG, *group)
|
||||
group.AccountID = account.Id
|
||||
account.GroupsG = append(account.GroupsG, group)
|
||||
}
|
||||
|
||||
for id, route := range account.Routes {
|
||||
@@ -449,25 +454,56 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
|
||||
// CreateGroups creates the given list of groups to the database.
|
||||
func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save groups to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateGroups updates the given list of groups to the database.
|
||||
func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
|
||||
if len(groups) == 0 {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Omit(clause.Associations).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save groups to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
@@ -646,7 +682,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountIDCondition, accountID)
|
||||
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -655,6 +691,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -669,6 +709,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
likePattern := `%"ID":"` + resourceID + `"%`
|
||||
|
||||
result := tx.
|
||||
Preload(clause.Associations).
|
||||
Where("resources LIKE ?", likePattern).
|
||||
Find(&groups)
|
||||
|
||||
@@ -679,6 +720,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -765,6 +810,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
Omit("GroupsG").
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
Preload(clause.Associations).
|
||||
First(&account, idQueryCondition, accountID)
|
||||
@@ -814,6 +860,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
}
|
||||
account.GroupsG = nil
|
||||
|
||||
var groupPeers []types.GroupPeer
|
||||
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
|
||||
Find(&groupPeers)
|
||||
for _, groupPeer := range groupPeers {
|
||||
if group, ok := account.Groups[groupPeer.GroupID]; ok {
|
||||
group.Peers = append(group.Peers, groupPeer.PeerID)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
|
||||
for _, route := range account.RoutesG {
|
||||
account.Routes[route.ID] = route.Copy()
|
||||
@@ -1311,55 +1368,76 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
|
||||
var group types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&group, "account_id = ? AND name = ?", accountID, "All")
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
var groupID string
|
||||
_ = s.db.Model(types.Group{}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND name = ?", accountID, "All").
|
||||
Limit(1).
|
||||
Scan(&groupID)
|
||||
|
||||
if groupID == "" {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(&types.GroupPeer{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}).Error
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
|
||||
var group types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
|
||||
First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
// AddPeerToGroup adds a peer to a group
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
||||
peer := &types.GroupPeer{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerId {
|
||||
return nil
|
||||
}
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err)
|
||||
return status.Errorf(status.Internal, "failed to add peer to group")
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
// RemovePeerFromGroup removes a peer from a group
|
||||
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from all groups")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1427,15 +1505,46 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
var groups []*types.Group
|
||||
query := tx.
|
||||
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
|
||||
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
|
||||
Where("group_peers.peer_id = ?", peerId).
|
||||
Preload(clause.Associations).
|
||||
Find(&groups)
|
||||
|
||||
if query.Error != nil {
|
||||
return nil, query.Error
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GetPeerGroupIDs retrieves all group IDs assigned to a specific peer in a given account.
|
||||
func (s *SqlStore) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var groupIDs []string
|
||||
query := tx.
|
||||
Model(&types.GroupPeer{}).
|
||||
Where("account_id = ? AND peer_id = ?", accountId, peerId).
|
||||
Pluck("group_id", &groupIDs)
|
||||
|
||||
if query.Error != nil {
|
||||
if errors.Is(query.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "no groups found for peer %s in account %s", peerId, accountId)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get group IDs for peer %s in account %s: %v", peerId, accountId, query.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group IDs for peer from store")
|
||||
}
|
||||
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
// GetAccountPeers retrieves peers for an account.
|
||||
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
@@ -1485,7 +1594,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1722,7 +1831,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
var group *types.Group
|
||||
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupID)
|
||||
@@ -1731,15 +1840,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var group types.Group
|
||||
|
||||
@@ -1747,16 +1855,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
// we may need to reconsider changing the types.
|
||||
query := tx.Preload(clause.Associations)
|
||||
|
||||
switch s.storeEngine {
|
||||
case types.PostgresStoreEngine:
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
case types.MysqlStoreEngine:
|
||||
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
|
||||
default:
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||
result := query.
|
||||
Model(&types.Group{}).
|
||||
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
|
||||
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
|
||||
Group("groups.id").
|
||||
Order("COUNT(group_peers.peer_id) DESC").
|
||||
Limit(1).
|
||||
First(&group)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupName)
|
||||
@@ -1764,6 +1870,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
@@ -1775,7 +1884,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||
@@ -1783,25 +1892,45 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
// SaveGroup saves a group to the store.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
|
||||
// CreateGroup creates a group in the store.
|
||||
func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
|
||||
if group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Create(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGroup updates a group in the store.
|
||||
func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
|
||||
if group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select(clause.Associations).
|
||||
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
||||
@@ -1818,6 +1947,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Select(clause.Associations).
|
||||
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||
@@ -2613,3 +2743,27 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var peers []types.GroupPeer
|
||||
result := tx.Find(&peers, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account group peers from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get account group peers from store")
|
||||
}
|
||||
|
||||
groupPeers := make(map[string]map[string]struct{})
|
||||
for _, peer := range peers {
|
||||
if _, exists := groupPeers[peer.GroupID]; !exists {
|
||||
groupPeers[peer.GroupID] = make(map[string]struct{})
|
||||
}
|
||||
groupPeers[peer.GroupID][peer.PeerID] = struct{}{}
|
||||
}
|
||||
|
||||
return groupPeers, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -1187,7 +1188,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
Peers: nil,
|
||||
}
|
||||
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
|
||||
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
err := transaction.CreateGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
t.Fatal("failed to save group")
|
||||
return err
|
||||
@@ -1348,7 +1349,8 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
func TestSqlStore_CreateGroup(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.MysqlStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
@@ -1356,12 +1358,14 @@ func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group := &types.Group{
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
}
|
||||
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
|
||||
@@ -1369,7 +1373,7 @@ func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
require.Equal(t, savedGroup, group)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
func TestSqlStore_CreateUpdateGroups(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
@@ -1378,23 +1382,27 @@ func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
|
||||
groups := []*types.Group{
|
||||
{
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
{
|
||||
ID: "group-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
Peers: []string{"peer3", "peer4"},
|
||||
ID: "group-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
Peers: []string{},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
|
||||
err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups[1].Peers = []string{}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
|
||||
err = store.UpdateGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID)
|
||||
@@ -2523,7 +2531,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
|
||||
require.NoError(t, err, "failed to get group")
|
||||
require.Len(t, group.Peers, 0, "group should have 0 peers")
|
||||
|
||||
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
|
||||
err = store.AddPeerToGroup(context.Background(), accountID, peerID, groupID)
|
||||
require.NoError(t, err, "failed to add peer to group")
|
||||
|
||||
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
@@ -2554,7 +2562,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
|
||||
require.NoError(t, err, "failed to add peer to account")
|
||||
|
||||
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
|
||||
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
require.NoError(t, err, "failed to add peer to all group")
|
||||
|
||||
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
@@ -2640,7 +2648,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
|
||||
assert.Len(t, groups, 1)
|
||||
assert.Equal(t, groups[0].Name, "All")
|
||||
|
||||
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
|
||||
err = store.AddPeerToGroup(context.Background(), accountID, peerID, "cfefqs706sqkneg59g4h")
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
@@ -3307,7 +3315,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
|
||||
err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
@@ -3538,3 +3546,64 @@ func TestSqlStore_GetAnyAccountID(t *testing.T) {
|
||||
assert.Empty(t, accountID)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkGetAccountPeers(b *testing.B) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", b.TempDir())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Cleanup(cleanup)
|
||||
|
||||
numberOfPeers := 1000
|
||||
numberOfGroups := 200
|
||||
numberOfPeersPerGroup := 500
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peers := make([]*nbpeer.Peer, 0, numberOfPeers)
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peer := &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer-%d", i),
|
||||
AccountID: accountID,
|
||||
DNSLabel: fmt.Sprintf("peer%d.example.com", i),
|
||||
IP: intToIPv4(uint32(i)),
|
||||
}
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to add peer: %v", err)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
|
||||
for i := 0; i < numberOfGroups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &types.Group{
|
||||
ID: groupID,
|
||||
AccountID: accountID,
|
||||
}
|
||||
err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create group: %v", err)
|
||||
}
|
||||
for j := 0; j < numberOfPeersPerGroup; j++ {
|
||||
peerIndex := (i*numberOfPeersPerGroup + j) % numberOfPeers
|
||||
err = store.AddPeerToGroup(context.Background(), accountID, peers[peerIndex].ID, groupID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to add peer to group: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peers[i%numberOfPeers].ID)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func intToIPv4(n uint32) net.IP {
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
@@ -101,8 +101,10 @@ type Store interface {
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
|
||||
UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
|
||||
CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
@@ -120,9 +122,12 @@ type Store interface {
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error
|
||||
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
|
||||
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
|
||||
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
||||
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
|
||||
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
||||
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
||||
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
|
||||
@@ -196,6 +201,7 @@ type Store interface {
|
||||
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
|
||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -353,6 +359,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(accountID, id, value string) any {
|
||||
return &types.GroupPeer{
|
||||
AccountID: accountID,
|
||||
GroupID: id,
|
||||
PeerID: value,
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ type Account struct {
|
||||
Users map[string]*User `gorm:"-"`
|
||||
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Groups map[string]*Group `gorm:"-"`
|
||||
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||
Routes map[route.ID]*route.Route `gorm:"-"`
|
||||
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
|
||||
@@ -26,7 +26,8 @@ type Group struct {
|
||||
Issued string
|
||||
|
||||
// Peers list of the group
|
||||
Peers []string `gorm:"serializer:json"`
|
||||
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
|
||||
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// Resources contains a list of resources in that group
|
||||
Resources []Resource `gorm:"serializer:json"`
|
||||
@@ -34,6 +35,32 @@ type Group struct {
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
}
|
||||
|
||||
type GroupPeer struct {
|
||||
AccountID string `gorm:"index"`
|
||||
GroupID string `gorm:"primaryKey"`
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
g.Peers[i] = peer.PeerID
|
||||
}
|
||||
g.GroupPeers = []GroupPeer{}
|
||||
}
|
||||
|
||||
func (g *Group) StoreGroupPeers() {
|
||||
g.GroupPeers = make([]GroupPeer, len(g.Peers))
|
||||
for i, peer := range g.Peers {
|
||||
g.GroupPeers[i] = GroupPeer{
|
||||
AccountID: g.AccountID,
|
||||
GroupID: g.ID,
|
||||
PeerID: peer,
|
||||
}
|
||||
}
|
||||
g.Peers = []string{}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the group
|
||||
func (g *Group) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
@@ -46,13 +73,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
|
||||
func (g *Group) Copy() *Group {
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
||||
Resources: make([]Resource, len(g.Resources)),
|
||||
IntegrationReference: g.IntegrationReference,
|
||||
}
|
||||
copy(group.Peers, g.Peers)
|
||||
copy(group.GroupPeers, g.GroupPeers)
|
||||
copy(group.Resources, g.Resources)
|
||||
return group
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ type SetupKey struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Key string
|
||||
KeySecret string
|
||||
KeySecret string `gorm:"index"`
|
||||
Name string
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
|
||||
@@ -677,13 +677,18 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
|
||||
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
||||
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, updatedGroups); err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err)
|
||||
addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups)
|
||||
for _, peer := range userPeers {
|
||||
for _, groupID := range removedGroups {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
|
||||
}
|
||||
}
|
||||
for _, groupID := range addedGroups {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1137,93 +1142,6 @@ func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID str
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
|
||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
userPeerIDMap := make(map[string]struct{}, len(peers))
|
||||
for _, peer := range peers {
|
||||
userPeerIDMap[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, gid := range groupsToAdd {
|
||||
group, ok := accountGroups[gid]
|
||||
if !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
if changed := addUserPeersToGroup(userPeerIDMap, group); changed {
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
}
|
||||
|
||||
for _, gid := range groupsToRemove {
|
||||
group, ok := accountGroups[gid]
|
||||
if !ok {
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed {
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
}
|
||||
|
||||
return groupsToUpdate, nil
|
||||
}
|
||||
|
||||
// addUserPeersToGroup adds the user's peers to the group.
|
||||
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
|
||||
groupPeers := make(map[string]struct{}, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
}
|
||||
|
||||
changed := false
|
||||
for pid := range userPeerIDs {
|
||||
if _, exists := groupPeers[pid]; !exists {
|
||||
groupPeers[pid] = struct{}{}
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = make([]string, 0, len(groupPeers))
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
if changed {
|
||||
group.Peers = make([]string, 0, len(groupPeers))
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// removeUserPeersFromGroup removes user's peers from the group.
|
||||
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
|
||||
// skip removing peers from group All
|
||||
if group.Name == "All" {
|
||||
return false
|
||||
}
|
||||
|
||||
updatedPeers := make([]string, 0, len(group.Peers))
|
||||
changed := false
|
||||
|
||||
for _, pid := range group.Peers {
|
||||
if _, owned := userPeerIDs[pid]; owned {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
updatedPeers = append(updatedPeers, pid)
|
||||
}
|
||||
|
||||
if changed {
|
||||
group.Peers = updatedPeers
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
|
||||
@@ -1335,11 +1335,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
// account groups propagation is enabled
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := &types.Policy{
|
||||
|
||||
Reference in New Issue
Block a user