[management] migrate group peers into seperate table (#4096)

This commit is contained in:
Pascal Fischer
2025-08-01 12:22:07 +02:00
committed by GitHub
parent 71bb09d870
commit 552dc60547
24 changed files with 1139 additions and 421 deletions

View File

@@ -1368,7 +1368,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
@@ -1382,28 +1382,22 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err)
for _, peer := range peers {
for _, g := range addNewGroups {
if err := transaction.AddPeerToGroup(ctx, userAuth.AccountId, peer.ID, g); err != nil {
return fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, g, err)
}
}
for _, g := range removeOldGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, g); err != nil {
return fmt.Errorf("error removing peer %s from group %s: %w", peer.ID, g, err)
}
}
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
@@ -1971,53 +1965,56 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsToUpdate := make(map[string]*types.Group)
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, fmt.Errorf("error getting account group peers: %w", err)
}
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, fmt.Errorf("error getting account groups: %w", err)
}
for _, group := range accountGroups {
if _, exists := accountGroupPeers[group.ID]; !exists {
accountGroupPeers[group.ID] = make(map[string]struct{})
}
}
updatedGroups := []string{}
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
if err != nil {
return false, false, err
}
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil)
if err != nil {
return false, false, err
}
for _, group := range updatedGroups {
groupsToUpdate[group.ID] = group
groupsMap[group.ID] = group
for _, peer := range userPeers {
for _, groupID := range user.AutoGroups {
if _, exists := accountGroupPeers[groupID]; !exists {
// we do not wanna create the groups here
log.WithContext(ctx).Warnf("group %s does not exist for user group propagation", groupID)
continue
}
if _, exists := accountGroupPeers[groupID][peer.ID]; exists {
continue
}
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, false, fmt.Errorf("error adding peer %s to group %s: %w", peer.ID, groupID, err)
}
updatedGroups = append(updatedGroups, groupID)
}
}
}
if len(groupsToUpdate) == 0 {
return false, false, nil
}
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate))
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, updatedGroups)
if err != nil {
return false, false, err
return false, false, fmt.Errorf("error checking if group changes affect peers: %w", err)
}
err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate))
if err != nil {
return false, false, err
}
return true, peersAffected, nil
return len(updatedGroups) > 0, peersAffected, nil
}

View File

@@ -62,8 +62,10 @@ type Manager interface {
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
SaveGroup(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group, create bool) error
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error

View File

@@ -1159,7 +1159,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Name: "GroupA",
Peers: []string{},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1194,7 +1194,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}()
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1240,11 +1240,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
AccountID: account.Id,
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1292,7 +1293,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1343,11 +1344,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
require.NoError(t, err, "failed to save group")
@@ -1672,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) {
},
Groups: map[string]*types.Group{
"group1": {
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
},
Policies: []*types.Policy{
@@ -2616,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
}
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -3360,7 +3363,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
@@ -3382,7 +3385,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)

View File

@@ -495,7 +495,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
func TestDNSAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -506,7 +506,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
Name: "GroupB",
Peers: []string{},
},
}, true)
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -562,11 +562,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
// Creating DNS settings with groups that have peers should update account peers and send peer update
t.Run("creating dns setting with used groups", func(t *testing.T) {
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
done := make(chan struct{})

View File

@@ -65,22 +65,144 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, create bool) error {
// CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}, create)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
for _, peerID := range newGroup.Peers {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
return nil
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// SaveGroups adds new groups to the account.
// UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
}
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
for _, peerID := range peersToAdd {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
for _, peerID := range peersToRemove {
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
}
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// CreateGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error {
operation := operations.Create
if !create {
operation = operations.Update
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operation)
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -116,7 +238,65 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
return err
}
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// UpdateGroups updates groups in the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var groupsToSave []*types.Group
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave)
})
if err != nil {
return err
@@ -265,20 +445,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -288,7 +458,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID)
})
if err != nil {
return err
@@ -329,7 +499,7 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
})
if err != nil {
return err
@@ -347,20 +517,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -370,7 +530,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
})
if err != nil {
return err
@@ -411,7 +571,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group)
})
if err != nil {
return err

View File

@@ -2,14 +2,20 @@ package server
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/groups"
@@ -18,8 +24,10 @@ import (
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -40,7 +48,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
for _, group := range account.Groups {
group.Issued = types.GroupIssuedIntegration
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
group.ID = uuid.New().String()
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration)
}
@@ -48,7 +57,8 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = types.GroupIssuedJWT
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
group.ID = uuid.New().String()
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", types.GroupIssuedJWT)
}
@@ -56,7 +66,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = types.GroupIssuedAPI
group.ID = ""
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group, true)
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, group)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
@@ -162,7 +172,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups, true)
err = manager.CreateGroups(context.Background(), account.Id, groupAdminUserID, groups)
assert.NoError(t, err, "Failed to save test groups")
testCases := []struct {
@@ -382,13 +392,13 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
return nil, nil, err
}
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers, true)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration, true)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
@@ -400,7 +410,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
func TestGroupAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -426,8 +436,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Name: "GroupE",
Peers: []string{peer2.ID},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
@@ -442,11 +455,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -513,7 +526,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})
// adding a group to policy
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*types.PolicyRule{
{
@@ -535,11 +548,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -604,11 +617,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -645,11 +658,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -672,11 +685,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -719,11 +732,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupE",
Name: "GroupE",
Peers: []string{peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -733,3 +746,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
})
}
func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
acc, err := createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToGroup(context.Background(), accountID, strconv.Itoa(i), acc.GroupsG[0].ID)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
peer := &peer2.Peer{
ID: strconv.Itoa(i),
AccountID: accountID,
DNSLabel: "peer" + strconv.Itoa(i),
IP: uint32ToIP(uint32(i)),
}
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
}
func uint32ToIP(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}
func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
}

View File

@@ -143,7 +143,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: existingGroup.IntegrationReference,
}
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, false); err != nil {
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
@@ -203,7 +203,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
Issued: types.GroupIssuedAPI,
}
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group, true)
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -39,6 +39,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
return nil
}
if !db.Migrator().HasColumn(&model, fieldName) {
log.WithContext(ctx).Debugf("Table for %T does not have column %s, no migration needed", model, fieldName)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
@@ -422,3 +427,62 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
return nil
}
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(accountID string, id string, value string) any) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if !db.Migrator().HasColumn(&model, columnName) {
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", "account_id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" {
continue
}
var data []string
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
return fmt.Errorf("unmarshal json: %w", err)
}
for _, value := range data {
if err := tx.Create(
mapperFunc(row["account_id"].(string), row["id"].(string), value),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
}
}
}
if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
return fmt.Errorf("drop column %s: %w", columnName, err)
}
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
return nil
}

View File

@@ -124,6 +124,34 @@ type MockAccountManager struct {
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(ctx, accountID, userID, group, true)
}
return status.Errorf(codes.Unimplemented, "method CreateGroup is not implemented")
}
func (am *MockAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(ctx, accountID, userID, group, false)
}
return status.Errorf(codes.Unimplemented, "method UpdateGroup is not implemented")
}
func (am *MockAccountManager) CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, true)
}
return status.Errorf(codes.Unimplemented, "method CreateGroups is not implemented")
}
func (am *MockAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error {
if am.SaveGroupsFunc != nil {
return am.SaveGroupsFunc(ctx, accountID, userID, newGroups, false)
}
return status.Errorf(codes.Unimplemented, "method UpdateGroups is not implemented")
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
if am.UpdateAccountPeersFunc != nil {
am.UpdateAccountPeersFunc(ctx, accountID)

View File

@@ -980,18 +980,18 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
}, true)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{},
})
assert.NoError(t, err)
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)

View File

@@ -374,12 +374,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return fmt.Errorf("failed to remove peer from groups: %w", err)
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
return err
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
@@ -478,7 +486,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var updateAccountPeers bool
var setupKeyID string
var setupKeyName string
@@ -615,20 +622,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return err
}
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
err = transaction.AddPeerToGroup(ctx, newPeer.AccountID, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -678,7 +685,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
@@ -1021,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -1523,17 +1530,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
groupIDs := make([]string, 0, len(groups))
for _, group := range groups {
groupIDs = append(groupIDs, group.ID)
}
return groupIDs, err
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
@@ -1563,17 +1560,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
}
for _, peer := range peers {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peer.ID)
if err != nil {
return nil, fmt.Errorf("failed to get peer groups: %w", err)
}
for _, group := range groups {
group.RemovePeer(peer.ID)
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
if err != nil {
return nil, fmt.Errorf("failed to save group: %w", err)
}
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID)
}
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID, settings.Extra); err != nil {

View File

@@ -310,12 +310,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)
err = manager.SaveGroup(context.Background(), account.Id, userID, &group1, true)
err = manager.CreateGroup(context.Background(), account.Id, userID, &group1)
if err != nil {
t.Errorf("expecting group1 to be added, got failure %v", err)
return
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &group2, true)
err = manager.CreateGroup(context.Background(), account.Id, userID, &group2)
if err != nil {
t.Errorf("expecting group2 to be added, got failure %v", err)
return
@@ -1475,6 +1475,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
}
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1709,7 +1713,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err)
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1725,8 +1729,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
}, true)
require.NoError(t, err)
}
for _, group := range g {
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
require.NoError(t, err)
}
// create a user with auto groups
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{
@@ -1785,7 +1792,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -2164,7 +2171,6 @@ func Test_IsUniqueConstraintError(t *testing.T) {
}
func Test_AddPeer(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -2176,7 +2182,7 @@ func Test_AddPeer(t *testing.T) {
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
t.Fatalf("error creating account: %v", err)
return
}
@@ -2186,22 +2192,21 @@ func Test_AddPeer(t *testing.T) {
return
}
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
const differentHostnames = 50
const totalPeers = 300
var wg sync.WaitGroup
errs := make(chan error, totalPeers+differentHostnames)
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
hostNameID := i % differentHostnames
go func(i int) {
defer wg.Done()
newPeer := &nbpeer.Peer{
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
AccountID: accountID,
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
}
<-start

View File

@@ -993,7 +993,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
func TestPolicyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1014,8 +1014,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
@@ -1025,6 +1028,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
var policyWithGroupRulesNoPeers *types.Policy
var policyWithDestinationPeersOnly *types.Policy
var policyWithSourceAndDestinationPeers *types.Policy
var err error
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/posture"
@@ -105,10 +105,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
Id: regularUserID,
Role: types.UserRoleUser,
}
peer1 := &peer.Peer{
ID: "peer1",
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
account.Peers["peer1"] = peer1
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
@@ -121,7 +125,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -137,8 +141,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err := manager.CreateGroup(context.Background(), account.Id, userID, group)
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
@@ -156,7 +163,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
postureCheckA, err := manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA, true)
require.NoError(t, err)
postureCheckB := &posture.Checks{
@@ -449,14 +456,16 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
AccountID: account.Id,
Peers: []string{"peer1"},
}
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to create groupA")
groupB := &types.Group{
ID: "groupB",
AccountID: account.Id,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, account.Id, []*types.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
err = manager.CreateGroup(context.Background(), account.Id, adminUserID, groupB)
require.NoError(t, err, "failed to create groupB")
postureCheckA := &posture.Checks{
Name: "checkA",
@@ -535,7 +544,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
groupA.Peers = []string{}
err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA)
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)

View File

@@ -1215,7 +1215,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
Name: "peer1 group",
Peers: []string{peer1ID},
}
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup, true)
err = am.CreateGroup(context.Background(), account.Id, userID, newGroup)
require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
@@ -1505,7 +1505,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
}
for _, group := range newGroup {
err = am.SaveGroup(context.Background(), accountID, userID, group, true)
err = am.CreateGroup(context.Background(), accountID, userID, group)
if err != nil {
return nil, err
}
@@ -1953,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
account, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
g := []*types.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1969,8 +1969,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Name: "GroupC",
Peers: []string{},
},
}, true)
assert.NoError(t, err)
}
for _, group := range g {
err = manager.CreateGroup(context.Background(), account.Id, userID, group)
require.NoError(t, err, "failed to create group %s", group.Name)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() {
@@ -2149,11 +2152,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1ID},
}, true)
})
assert.NoError(t, err)
select {
@@ -2189,11 +2192,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.UpdateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1ID},
}, true)
})
assert.NoError(t, err)
select {

View File

@@ -29,7 +29,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
err = manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
ID: "group_1",
Name: "group_name_1",
@@ -40,7 +40,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
Name: "group_name_2",
Peers: []string{},
},
}, true)
})
if err != nil {
t.Fatal(err)
}
@@ -104,20 +104,20 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
}, true)
})
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err = manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
}, true)
})
if err != nil {
t.Fatal(err)
}
@@ -398,11 +398,11 @@ func TestSetupKey_Copy(t *testing.T) {
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
assert.NoError(t, err)
policy := &types.Policy{

View File

@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
for _, group := range account.GroupsG {
group.StoreGroupPeers()
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
@@ -247,7 +251,8 @@ func generateAccountSQLTypes(account *types.Account) {
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
group.AccountID = account.Id
account.GroupsG = append(account.GroupsG, group)
}
for id, route := range account.Routes {
@@ -449,25 +454,56 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
// CreateGroups creates the given list of groups to the database.
func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
return s.db.Transaction(func(tx *gorm.DB) error {
result := tx.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Omit(clause.Associations).
Create(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
return nil
})
}
// UpdateGroups updates the given list of groups to the database.
func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error {
if len(groups) == 0 {
return nil
}
return nil
return s.db.Transaction(func(tx *gorm.DB) error {
result := tx.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Omit(clause.Associations).
Create(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
return nil
})
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -646,7 +682,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -655,6 +691,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -669,6 +709,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
likePattern := `%"ID":"` + resourceID + `"%`
result := tx.
Preload(clause.Associations).
Where("resources LIKE ?", likePattern).
Find(&groups)
@@ -679,6 +720,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
return nil, result.Error
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -765,6 +810,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
var account types.Account
result := s.db.Model(&account).
Omit("GroupsG").
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, idQueryCondition, accountID)
@@ -814,6 +860,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
}
account.GroupsG = nil
var groupPeers []types.GroupPeer
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
Find(&groupPeers)
for _, groupPeer := range groupPeers {
if group, ok := account.Groups[groupPeer.GroupID]; ok {
group.Peers = append(group.Peers, groupPeer.PeerID)
} else {
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
}
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
@@ -1311,55 +1368,76 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var groupID string
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
Scan(&groupID)
if groupID == "" {
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(&types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
PeerID: peerID,
}).Error
group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
if err != nil {
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
}
return nil
}
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
PeerID: peerID,
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer %s to group %s for account %s: %v", peerID, groupID, accountID, err)
return status.Errorf(status.Internal, "failed to add peer to group")
}
group.Peers = append(group.Peers, peerId)
return nil
}
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to remove peer from group")
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
return status.Errorf(status.Internal, "failed to remove peer from all groups")
}
return nil
@@ -1427,15 +1505,46 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
Where("group_peers.peer_id = ?", peerId).
Preload(clause.Associations).
Find(&groups)
if query.Error != nil {
return nil, query.Error
}
for _, group := range groups {
group.LoadGroupPeers()
}
return groups, nil
}
// GetPeerGroupIDs retrieves all group IDs assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groupIDs []string
query := tx.
Model(&types.GroupPeer{}).
Where("account_id = ? AND peer_id = ?", accountId, peerId).
Pluck("group_id", &groupIDs)
if query.Error != nil {
if errors.Is(query.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no groups found for peer %s in account %s", peerId, accountId)
}
log.WithContext(ctx).Errorf("failed to get group IDs for peer %s in account %s: %v", peerId, accountId, query.Error)
return nil, status.Errorf(status.Internal, "failed to get group IDs for peer from store")
}
return groupIDs, nil
}
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
@@ -1485,7 +1594,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1722,7 +1831,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
@@ -1731,15 +1840,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
group.LoadGroupPeers()
return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group types.Group
@@ -1747,16 +1855,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// we may need to reconsider changing the types.
query := tx.Preload(clause.Associations)
switch s.storeEngine {
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
result := query.
Model(&types.Group{}).
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
Group("groups.id").
Order("COUNT(group_peers.peer_id) DESC").
Limit(1).
First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupName)
@@ -1764,6 +1870,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
}
group.LoadGroupPeers()
return &group, nil
}
@@ -1775,7 +1884,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
}
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
@@ -1783,25 +1892,45 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
groupsMap[group.ID] = group
}
return groupsMap, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
// CreateGroup creates a group in the store.
func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
if err := s.db.Omit(clause.Associations).Create(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
// UpdateGroup updates a group in the store.
func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
@@ -1818,6 +1947,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
@@ -2613,3 +2743,27 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
return count, nil
}
func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []types.GroupPeer
result := tx.Find(&peers, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get account group peers from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account group peers from store")
}
groupPeers := make(map[string]map[string]struct{})
for _, peer := range peers {
if _, exists := groupPeers[peer.GroupID]; !exists {
groupPeers[peer.GroupID] = make(map[string]struct{})
}
groupPeers[peer.GroupID][peer.PeerID] = struct{}{}
}
return groupPeers, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"encoding/binary"
"fmt"
"math/rand"
"net"
@@ -1187,7 +1188,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
Peers: nil,
}
err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
err := transaction.CreateGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
@@ -1348,7 +1349,8 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) {
}
}
func TestSqlStore_SaveGroup(t *testing.T) {
func TestSqlStore_CreateGroup(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.MysqlStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -1356,12 +1358,14 @@ func TestSqlStore_SaveGroup(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &types.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
@@ -1369,7 +1373,7 @@ func TestSqlStore_SaveGroup(t *testing.T) {
require.Equal(t, savedGroup, group)
}
func TestSqlStore_SaveGroups(t *testing.T) {
func TestSqlStore_CreateUpdateGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@@ -1378,23 +1382,27 @@ func TestSqlStore_SaveGroups(t *testing.T) {
groups := []*types.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
require.NoError(t, err)
groups[1].Peers = []string{}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
err = store.UpdateGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID)
@@ -2523,7 +2531,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
require.NoError(t, err, "failed to get group")
require.Len(t, group.Peers, 0, "group should have 0 peers")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
err = store.AddPeerToGroup(context.Background(), accountID, peerID, groupID)
require.NoError(t, err, "failed to add peer to group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2554,7 +2562,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
require.NoError(t, err, "failed to add peer to account")
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
require.NoError(t, err, "failed to add peer to all group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2640,7 +2648,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
assert.Len(t, groups, 1)
assert.Equal(t, groups[0].Name, "All")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
err = store.AddPeerToGroup(context.Background(), accountID, peerID, "cfefqs706sqkneg59g4h")
require.NoError(t, err)
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)
@@ -3307,7 +3315,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
})
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave)
require.NoError(t, err)
accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
@@ -3538,3 +3546,64 @@ func TestSqlStore_GetAnyAccountID(t *testing.T) {
assert.Empty(t, accountID)
})
}
func BenchmarkGetAccountPeers(b *testing.B) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", b.TempDir())
if err != nil {
b.Fatal(err)
}
b.Cleanup(cleanup)
numberOfPeers := 1000
numberOfGroups := 200
numberOfPeersPerGroup := 500
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peers := make([]*nbpeer.Peer, 0, numberOfPeers)
for i := 0; i < numberOfPeers; i++ {
peer := &nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
AccountID: accountID,
DNSLabel: fmt.Sprintf("peer%d.example.com", i),
IP: intToIPv4(uint32(i)),
}
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
if err != nil {
b.Fatalf("Failed to add peer: %v", err)
}
peers = append(peers, peer)
}
for i := 0; i < numberOfGroups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{
ID: groupID,
AccountID: accountID,
}
err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
b.Fatalf("Failed to create group: %v", err)
}
for j := 0; j < numberOfPeersPerGroup; j++ {
peerIndex := (i*numberOfPeersPerGroup + j) % numberOfPeers
err = store.AddPeerToGroup(context.Background(), accountID, peers[peerIndex].ID, groupID)
if err != nil {
b.Fatalf("Failed to add peer to group: %v", err)
}
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peers[i%numberOfPeers].ID)
if err != nil {
b.Fatal(err)
}
}
}
func intToIPv4(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}

View File

@@ -101,8 +101,10 @@ type Store interface {
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error
CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
@@ -120,9 +122,12 @@ type Store interface {
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountID, peerId string, groupID string) error
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error
@@ -196,6 +201,7 @@ type Store interface {
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
}
const (
@@ -353,6 +359,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(accountID, id, value string) any {
return &types.GroupPeer{
AccountID: accountID,
GroupID: id,
PeerID: value,
}
})
},
}
}

View File

@@ -73,7 +73,7 @@ type Account struct {
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`

View File

@@ -26,7 +26,8 @@ type Group struct {
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
// Resources contains a list of resources in that group
Resources []Resource `gorm:"serializer:json"`
@@ -34,6 +35,32 @@ type Group struct {
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
type GroupPeer struct {
AccountID string `gorm:"index"`
GroupID string `gorm:"primaryKey"`
PeerID string `gorm:"primaryKey"`
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
g.Peers[i] = peer.PeerID
}
g.GroupPeers = []GroupPeer{}
}
func (g *Group) StoreGroupPeers() {
g.GroupPeers = make([]GroupPeer, len(g.Peers))
for i, peer := range g.Peers {
g.GroupPeers[i] = GroupPeer{
AccountID: g.AccountID,
GroupID: g.ID,
PeerID: peer,
}
}
g.Peers = []string{}
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
@@ -46,13 +73,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
AccountID: g.AccountID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
Resources: make([]Resource, len(g.Resources)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
copy(group.GroupPeers, g.GroupPeers)
copy(group.Resources, g.Resources)
return group
}

View File

@@ -35,7 +35,7 @@ type SetupKey struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
KeySecret string `gorm:"index"`
Name string
Type SetupKeyType
CreatedAt time.Time

View File

@@ -677,13 +677,18 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups)
if err != nil {
return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, updatedGroups); err != nil {
return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err)
addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups)
for _, peer := range userPeers {
for _, groupID := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
}
}
for _, groupID := range addedGroups {
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
}
}
}
}
@@ -1137,93 +1142,6 @@ func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID str
return userInfo, nil
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, groupsToRemove []string) (groupsToUpdate []*types.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
if changed := addUserPeersToGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group)
}
}
for _, gid := range groupsToRemove {
group, ok := accountGroups[gid]
if !ok {
return nil, errors.New("group not found")
}
if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group)
}
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
changed := false
for pid := range userPeerIDs {
if _, exists := groupPeers[pid]; !exists {
groupPeers[pid] = struct{}{}
changed = true
}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
if changed {
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
return changed
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
// skip removing peers from group All
if group.Name == "All" {
return false
}
updatedPeers := make([]string, 0, len(group.Peers))
changed := false
for _, pid := range group.Peers {
if _, owned := userPeerIDs[pid]; owned {
changed = true
continue
}
updatedPeers = append(updatedPeers, pid)
}
if changed {
group.Peers = updatedPeers
}
return changed
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {

View File

@@ -1335,11 +1335,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
// account groups propagation is enabled
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
require.NoError(t, err)
policy := &types.Policy{