Compare commits

...

35 Commits

Author SHA1 Message Date
Pascal Fischer
261d1e094a add debug logging to getAccount 2025-07-04 17:46:12 +02:00
Pascal Fischer
bcab5cbbee fix migration check 2025-07-04 15:23:03 +02:00
Pascal Fischer
3931958499 disable avoid adding networks multiple times 2025-07-04 14:19:13 +02:00
Pascal Fischer
914e58ac75 Merge branch 'feature/net-266-groups-migration' into feature/net-297-network-migration 2025-07-04 14:18:18 +02:00
Pascal Fischer
7baeea3d9d disable avoid adding group peers multiple times 2025-07-04 14:17:34 +02:00
Pascal Fischer
64e618f1ad disable migration cleanup 2025-07-04 13:14:47 +02:00
Pascal Fischer
5802dcaf80 Merge branch 'feature/net-266-groups-migration' into feature/net-297-network-migration
# Conflicts:
#	management/server/migration/migration.go
#	management/server/peer_test.go
2025-07-04 13:09:01 +02:00
Pascal Fischer
b329397d06 Revert "use FullSaveAssociations for single group save"
This reverts commit 76b180a741.
2025-07-04 10:16:41 +02:00
Pascal Fischer
76b180a741 use FullSaveAssociations for single group save 2025-07-04 00:00:40 +02:00
Pascal Fischer
8db9065ed9 run account update buffer in goroutine 2025-07-03 20:36:35 +02:00
Pascal Fischer
1f79fc0728 ignore associations on create 2025-07-03 19:58:11 +02:00
Pascal Fischer
bdd0b1cf02 update get group by name 2025-07-03 19:24:12 +02:00
Pascal Fischer
e6ac248aee load peers on retrival by name 2025-07-03 19:11:18 +02:00
Pascal Fischer
376394f7f9 fix large batch 2025-07-03 19:00:15 +02:00
Pascal Fischer
542dbdb41c fix group save with removal of all peers 2025-07-03 18:50:17 +02:00
Pascal Fischer
982b9604ee cleanup 2025-07-03 18:41:49 +02:00
Pascal Fischer
f2990e2fbc cleanup 2025-07-03 18:10:00 +02:00
Pascal Fischer
dfb47d5545 fix tests 2025-07-03 17:20:21 +02:00
Pascal Fischer
8e0b8f20a2 fix tests 2025-07-03 16:56:47 +02:00
Pascal Fischer
8a42528664 change getGroupsByPeers 2025-07-03 16:24:00 +02:00
Pascal Fischer
a8cba921e1 fix tests and group copy 2025-07-03 16:04:33 +02:00
Pascal Fischer
fee36b0663 fix peer add if group not existing 2025-07-03 15:40:28 +02:00
Pascal Fischer
dfad334780 fix peer add if group not existing 2025-07-03 15:30:35 +02:00
Pascal Fischer
d25da87957 don't fail on conflict 2025-07-03 15:15:48 +02:00
Pascal Fischer
13213d954d fix save account 2025-07-03 14:49:12 +02:00
Pascal Fischer
6fb61c7cf5 remove group peers from extended store 2025-07-03 13:45:59 +02:00
Pascal Fischer
459db2ba4f fix log in test 2025-07-03 13:41:36 +02:00
Pascal Fischer
e78b7dd058 fis association replace 2025-07-03 13:35:18 +02:00
Pascal Fischer
7132642e4c cleanup tests 2025-07-03 13:23:14 +02:00
Pascal Fischer
22a944b157 cleanup tests 2025-07-03 13:15:56 +02:00
Pascal Fischer
005937ae77 Update management/server/migration/migration.go
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-03 12:56:39 +02:00
Pascal Fischer
5fab2d019a set lock level on update 2025-07-03 12:55:11 +02:00
Pascal Fischer
36155f8de1 disable delete on migration 2025-07-03 12:51:55 +02:00
Pascal Fischer
d06831dd2f add network migration 2025-07-03 11:20:16 +02:00
Pascal Fischer
e23282b92c add groups migration 2025-07-02 19:09:42 +02:00
16 changed files with 736 additions and 190 deletions

View File

@@ -1688,7 +1688,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
network := types.NewNetwork(accountID)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
@@ -1792,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
continue
}
network := types.NewNetwork()
network := types.NewNetwork(accountId)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)

View File

@@ -1240,9 +1240,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
AccountID: account.Id,
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
t.Errorf("save group: %v", err)
@@ -1672,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) {
},
Groups: map[string]*types.Group{
"group1": {
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
},
Policies: []*types.Policy{
@@ -2616,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
}
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

View File

@@ -265,20 +265,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.AddPeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -288,7 +278,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.AddPeerToGroup(ctx, peerID, groupID)
})
if err != nil {
return err
@@ -347,20 +337,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
if updated := group.RemovePeer(peerID); !updated {
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
@@ -370,7 +350,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
})
if err != nil {
return err

View File

@@ -2,14 +2,19 @@ package server
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/groups"
@@ -18,8 +23,10 @@ import (
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@@ -733,3 +740,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
})
}
func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
acc, err := createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToGroup(context.Background(), strconv.Itoa(i), acc.GroupsG[0].ID)
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
if err != nil {
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
}
func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
peer := &peer2.Peer{
ID: strconv.Itoa(i),
AccountID: accountID,
DNSLabel: "peer" + strconv.Itoa(i),
IP: uint32ToIP(uint32(i)),
}
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
}
func uint32ToIP(n uint32) net.IP {
ip := make(net.IP, 4)
binary.BigEndian.PutUint32(ip, n)
return ip
}
func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
accountID := "testaccount"
userID := "testuser"
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
return
}
const totalPeers = 1000
var wg sync.WaitGroup
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
<-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err)
}
return nil
})
if err != nil {
t.Errorf("AddPeer failed for peer %d: %v", i, err)
return
}
}(i)
}
startTime := time.Now()
close(start)
wg.Wait()
close(errs)
t.Logf("time since start: %s", time.Since(startTime))
for err := range errs {
t.Fatal(err)
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Fatalf("Failed to get account %s: %v", accountID, err)
}
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
}

View File

@@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
handler := initAccountsTestData(t, &types.Account{
Id: accountID,
Domain: "hotmail.com",
Network: types.NewNetwork(),
Network: types.NewNetwork(accountID),
Users: map[string]*types.User{
adminUser.Id: adminUser,
},

View File

@@ -10,13 +10,25 @@ import (
"errors"
"fmt"
"net"
"reflect"
"strings"
"unicode/utf8"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
type LegacyAccountNetwork struct {
AccountID string `gorm:"column:id"`
Identifier string `gorm:"column:network_identifier"`
Net net.IPNet `gorm:"column:network_net;serializer:json"`
Dns string `gorm:"column:network_dns"`
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64 `gorm:"column:network_serial"`
}
func GetColumnName(db *gorm.DB, column string) string {
if db.Name() == "mysql" {
return fmt.Sprintf("`%s`", column)
@@ -39,6 +51,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
return nil
}
if !db.Migrator().HasColumn(&model, oldColumnName) {
log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(model)
if err != nil {
@@ -412,3 +429,149 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
return nil
}
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(id string, value string) any) error {
var model T
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if !db.Migrator().HasColumn(&model, columnName) {
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" {
continue
}
var data []string
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
return fmt.Errorf("unmarshal json: %w", err)
}
for _, value := range data {
if err := tx.Clauses(clause.OnConflict{
DoNothing: true, // this needs to be removed when the cleanup is enabled
}).Create(
mapperFunc(row["id"].(string), value),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
}
}
}
// Todo: Enable this after we are sure that every thing works as expected and we do not need to rollback anymore
// if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
// return fmt.Errorf("drop column %s: %w", columnName, err)
// }
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
return nil
}
func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error {
var model T
log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model)
if !db.Migrator().HasTable(&model) {
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&model)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if err := db.Transaction(func(tx *gorm.DB) error {
var legacyRows []S
if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil {
log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err)
return fmt.Errorf("failed to read legacy accounts: %w", err)
}
for _, row := range legacyRows {
if err := tx.Clauses(clause.OnConflict{
DoNothing: true, // this needs to be removed when the cleanup is enabled
}).Create(
mapperFunc(row),
).Error; err != nil {
return fmt.Errorf("failed to insert id %v: %w", row, err)
}
}
// cols, err := getColumnNamesFromStruct(new(S))
// if err != nil {
// return fmt.Errorf("failed to extract column names: %w", err)
// }
// for _, col := range cols {
// if col == pkey {
// continue
// }
// if err := tx.Migrator().DropColumn(&model, col); err != nil {
// return fmt.Errorf("failed to drop column %s: %w", col, err)
// }
// }
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName)
return nil
}
func getColumnNamesFromStruct[T any](model T) ([]string, error) {
val := reflect.TypeOf(model)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
var cols []string
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
if field.Name == "ID" {
continue // skip primary key
}
tag := field.Tag.Get("gorm")
if tag == "" {
continue
}
// Look for gorm:"column:..."
for _, part := range strings.Split(tag, ";") {
if strings.HasPrefix(part, "column:") {
cols = append(cols, strings.TrimPrefix(part, "column:"))
break
}
}
}
return cols, nil
}

View File

@@ -360,25 +360,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to get peer groups: %w", err)
}
for _, group := range groups {
group.RemovePeer(peerID)
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
if err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return fmt.Errorf("failed to remove peer from groups: %w", err)
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
return err
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
@@ -477,7 +472,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var updateAccountPeers bool
var setupKeyID string
var setupKeyName string
@@ -594,12 +588,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
// unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
// defer func() {
// if unlock != nil {
// unlock()
// }
// }()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
@@ -607,20 +601,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return err
}
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if len(groupsToAdd) > 0 {
for _, g := range groupsToAdd {
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
err = transaction.AddPeerToGroup(ctx, newPeer.ID, g)
if err != nil {
return err
}
}
}
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
if err != nil {
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -652,14 +646,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil
})
if err == nil {
unlock()
unlock = nil
// unlock()
// unlock = nil
break
}
if isUniqueConstraintError(err) {
unlock()
unlock = nil
// unlock()
// unlock = nil
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
continue
}
@@ -670,11 +664,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
@@ -687,9 +676,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID)
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
@@ -1019,7 +1006,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -1268,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex)
if !lock.TryLock() {
return
}
go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID)
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
lock := mu.(*sync.Mutex)
if !lock.TryLock() {
return
}
go func() {
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
lock.Unlock()
am.UpdateAccountPeers(ctx, accountID)
}()
}()
}

View File

@@ -1459,6 +1459,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
}
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
}
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1764,7 +1768,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -2144,6 +2148,8 @@ func Test_IsUniqueConstraintError(t *testing.T) {
func Test_AddPeer(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms")
t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms")
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -2155,7 +2161,7 @@ func Test_AddPeer(t *testing.T) {
_, err = createAccount(manager, accountID, userID, "domain.com")
if err != nil {
t.Fatal("error creating account")
t.Fatalf("error creating account: %v", err)
return
}
@@ -2165,22 +2171,21 @@ func Test_AddPeer(t *testing.T) {
return
}
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
const differentHostnames = 50
const totalPeers = 300
var wg sync.WaitGroup
errs := make(chan error, totalPeers+differentHostnames)
errs := make(chan error, totalPeers)
start := make(chan struct{})
for i := 0; i < totalPeers; i++ {
wg.Add(1)
hostNameID := i % differentHostnames
go func(i int) {
defer wg.Done()
newPeer := &nbpeer.Peer{
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
AccountID: accountID,
Key: "key" + strconv.Itoa(i),
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
}
<-start

View File

@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
@@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
for _, group := range account.GroupsG {
group.StoreGroupPeers()
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
@@ -247,7 +251,7 @@ func generateAccountSQLTypes(account *types.Account) {
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
account.GroupsG = append(account.GroupsG, group)
}
for id, route := range account.Routes {
@@ -455,19 +459,40 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
return nil
}
result := s.db.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
for _, g := range groups {
g.StoreGroupPeers()
}
return nil
return s.db.Transaction(func(tx *gorm.DB) error {
result := tx.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
for _, g := range groups {
if len(g.GroupPeers) == 0 {
if err := tx.Where("group_id = ?", g.ID).Delete(&types.GroupPeer{}).Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", g.ID, err)
return status.Errorf(status.Internal, "failed to delete group peers")
}
} else {
if err := tx.Model(&g).Association("GroupPeers").Replace(g.GroupPeers); err != nil {
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
}
}
}
return nil
})
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -646,7 +671,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -655,6 +680,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -669,6 +698,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
likePattern := `%"ID":"` + resourceID + `"%`
result := tx.
Preload(clause.Associations).
Where("resources LIKE ?", likePattern).
Find(&groups)
@@ -679,6 +709,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
return nil, result.Error
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -738,8 +772,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
}()
var account types.Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
result := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload("GroupsG.GroupPeers"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, idQueryCondition, accountID)
if result.Error != nil {
@@ -753,7 +788,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*types.PolicyRule
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
err := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
@@ -784,6 +819,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
group.LoadGroupPeers()
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
@@ -998,14 +1034,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
accountNetwork := types.Network{}
if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
return &accountNetwork, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
@@ -1285,55 +1321,74 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var groupID string
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
Scan(&groupID)
if groupID == "" {
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(&types.GroupPeer{
GroupID: groupID,
PeerID: peerID,
}).Error
group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
if err != nil {
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
}
return nil
}
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, peerID string, groupID string) error {
peer := &types.GroupPeer{
GroupID: groupID,
PeerID: peerID,
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer %s to group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to add peer to group")
}
group.Peers = append(group.Peers, peerId)
return nil
}
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to remove peer from group")
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
return status.Errorf(status.Internal, "failed to remove peer from all groups")
}
return nil
@@ -1401,12 +1456,19 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
Where("group_peers.peer_id = ?", peerId).
Preload(clause.Associations).
Find(&groups)
if query.Error != nil {
return nil, query.Error
}
for _, group := range groups {
group.LoadGroupPeers()
}
return groups, nil
}
@@ -1455,7 +1517,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1583,7 +1645,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -1692,7 +1754,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
@@ -1701,15 +1763,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
group.LoadGroupPeers()
return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group types.Group
@@ -1717,16 +1778,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// we may need to reconsider changing the types.
query := tx.Preload(clause.Associations)
switch s.storeEngine {
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
result := query.
Model(&types.Group{}).
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
Group("groups.id").
Order("COUNT(group_peers.peer_id) DESC").
Limit(1).
First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupName)
@@ -1734,6 +1793,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
}
group.LoadGroupPeers()
return &group, nil
}
@@ -1745,7 +1807,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
}
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
@@ -1753,6 +1815,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
groupsMap[group.ID] = group
}
@@ -1761,17 +1824,36 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
group = group.Copy()
group.StoreGroupPeers()
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
if len(group.GroupPeers) == 0 {
if err := s.db.Where("group_id = ?", group.ID).Delete(&types.GroupPeer{}).Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", group.ID, err)
return status.Errorf(status.Internal, "failed to delete group peers")
}
} else {
if err := s.db.Model(&group).Association("GroupPeers").Replace(group.GroupPeers); err != nil {
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
}
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
@@ -1788,6 +1870,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)

View File

@@ -1340,10 +1340,12 @@ func TestSqlStore_SaveGroup(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group := &types.Group{
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-id",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
}
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
require.NoError(t, err)
@@ -1362,16 +1364,19 @@ func TestSqlStore_SaveGroups(t *testing.T) {
groups := []*types.Group{
{
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
ID: "group-1",
AccountID: accountID,
Issued: "api",
Peers: []string{"peer1", "peer2"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
{
ID: "group-2",
AccountID: accountID,
Issued: "integration",
Peers: []string{"peer3", "peer4"},
Resources: []types.Resource{},
},
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
@@ -2059,7 +2064,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
network := types.NewNetwork(accountID)
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[nbroute.ID]*nbroute.Route)
@@ -2506,7 +2511,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
require.NoError(t, err, "failed to get group")
require.Len(t, group.Peers, 0, "group should have 0 peers")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
err = store.AddPeerToGroup(context.Background(), peerID, groupID)
require.NoError(t, err, "failed to add peer to group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2537,7 +2542,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
require.NoError(t, err, "failed to add peer to account")
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
require.NoError(t, err, "failed to add peer to all group")
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
@@ -2623,7 +2628,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
assert.Len(t, groups, 1)
assert.Equal(t, groups[0].Name, "All")
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
err = store.AddPeerToGroup(context.Background(), peerID, "cfefqs706sqkneg59g4h")
require.NoError(t, err)
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)

View File

@@ -118,8 +118,10 @@ type Store interface {
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, peerId string, groupID string) error
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
@@ -351,6 +353,25 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(id, value string) any {
return &types.GroupPeer{
GroupID: id,
PeerID: value,
}
})
},
func(db *gorm.DB) error {
return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network {
return &types.Network{
AccountID: obj.AccountID,
Identifier: obj.Identifier,
Net: obj.Net,
Serial: obj.Serial,
Dns: obj.Dns,
}
})
},
}
}

View File

@@ -67,13 +67,13 @@ type Account struct {
IsDomainPrimaryAccount bool
SetupKeys map[string]*SetupKey `gorm:"-"`
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"`
Peers map[string]*nbpeer.Peer `gorm:"-"`
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[route.ID]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`

View File

@@ -26,7 +26,8 @@ type Group struct {
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
Peers []string `gorm:"-"`
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
// Resources contains a list of resources in that group
Resources []Resource `gorm:"serializer:json"`
@@ -34,6 +35,29 @@ type Group struct {
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
type GroupPeer struct {
GroupID string `gorm:"primaryKey"`
PeerID string `gorm:"primaryKey"`
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
g.Peers[i] = peer.PeerID
}
g.GroupPeers = []GroupPeer{}
}
func (g *Group) StoreGroupPeers() {
g.GroupPeers = make([]GroupPeer, len(g.Peers))
for i, peer := range g.Peers {
g.GroupPeers[i] = GroupPeer{
GroupID: g.ID,
PeerID: peer,
}
}
g.Peers = []string{}
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
@@ -46,13 +70,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
AccountID: g.AccountID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
Resources: make([]Resource, len(g.Resources)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
copy(group.GroupPeers, g.GroupPeers)
copy(group.Resources, g.Resources)
return group
}

View File

@@ -107,7 +107,8 @@ func ipToBytes(ip net.IP) []byte {
}
type Network struct {
Identifier string `json:"id"`
AccountID string `gorm:"primaryKey"`
Identifier string `gorm:"index"`
Net net.IPNet `gorm:"serializer:json"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
@@ -117,9 +118,13 @@ type Network struct {
Mu sync.Mutex `json:"-" gorm:"-"`
}
func (*Network) TableName() string {
return "account_networks"
}
// NewNetwork creates a new Network initializing it with a Serial=0
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
func NewNetwork(accountID string) *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
@@ -129,6 +134,7 @@ func NewNetwork() *Network {
intn := r.Intn(len(sub))
return &Network{
AccountID: accountID,
Identifier: xid.New().String(),
Net: sub[intn].IPNet,
Dns: "",
@@ -151,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 {
func (n *Network) Copy() *Network {
return &Network{
AccountID: n.AccountID,
Identifier: n.Identifier,
Net: n.Net,
Dns: n.Dns,

View File

@@ -8,7 +8,7 @@ import (
)
func TestNewNetwork(t *testing.T) {
network := NewNetwork()
network := NewNetwork("accountID")
// generated net should be a subnet of a larger 100.64.0.0/10 net
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}

View File

@@ -35,7 +35,7 @@ type SetupKey struct {
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
KeySecret string `gorm:"index"`
Name string
Type SetupKeyType
CreatedAt time.Time