mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
Compare commits
35 Commits
snyk-fix-9
...
feature/ne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
261d1e094a | ||
|
|
bcab5cbbee | ||
|
|
3931958499 | ||
|
|
914e58ac75 | ||
|
|
7baeea3d9d | ||
|
|
64e618f1ad | ||
|
|
5802dcaf80 | ||
|
|
b329397d06 | ||
|
|
76b180a741 | ||
|
|
8db9065ed9 | ||
|
|
1f79fc0728 | ||
|
|
bdd0b1cf02 | ||
|
|
e6ac248aee | ||
|
|
376394f7f9 | ||
|
|
542dbdb41c | ||
|
|
982b9604ee | ||
|
|
f2990e2fbc | ||
|
|
dfb47d5545 | ||
|
|
8e0b8f20a2 | ||
|
|
8a42528664 | ||
|
|
a8cba921e1 | ||
|
|
fee36b0663 | ||
|
|
dfad334780 | ||
|
|
d25da87957 | ||
|
|
13213d954d | ||
|
|
6fb61c7cf5 | ||
|
|
459db2ba4f | ||
|
|
e78b7dd058 | ||
|
|
7132642e4c | ||
|
|
22a944b157 | ||
|
|
005937ae77 | ||
|
|
5fab2d019a | ||
|
|
36155f8de1 | ||
|
|
d06831dd2f | ||
|
|
e23282b92c |
@@ -1688,7 +1688,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountID)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
@@ -1792,7 +1792,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
continue
|
||||
}
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountId)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
|
||||
@@ -1240,9 +1240,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
AccountID: account.Id,
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
@@ -1672,9 +1673,10 @@ func TestAccount_Copy(t *testing.T) {
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -2616,6 +2618,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
|
||||
@@ -265,20 +265,10 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.AddPeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -288,7 +278,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.AddPeerToGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -347,20 +337,10 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updated := group.RemovePeer(peerID); !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -370,7 +350,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,14 +2,19 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -18,8 +23,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -733,3 +740,259 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_AddPeerToGroup(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
acc, err := createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToGroup(context.Background(), strconv.Itoa(i), acc.GroupsG[0].ID)
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.AddPeerToAllGroup(context.Background(), accountID, strconv.Itoa(i))
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func Test_AddPeerAndAddToAll(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
peer := &peer2.Peer{
|
||||
ID: strconv.Itoa(i),
|
||||
AccountID: accountID,
|
||||
DNSLabel: "peer" + strconv.Itoa(i),
|
||||
IP: uint32ToIP(uint32(i)),
|
||||
}
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
err = transaction.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, len(maps.Values(account.Groups)[0].Peers), "Expected %d peers in group %s in account %s, got %d", totalPeers, maps.Values(account.Groups)[0].Name, accountID, len(account.Peers))
|
||||
assert.Equal(t, totalPeers, len(account.Peers), "Expected %d peers in account %s, got %d", totalPeers, accountID, len(account.Peers))
|
||||
}
|
||||
|
||||
func uint32ToIP(n uint32) net.IP {
|
||||
ip := make(net.IP, 4)
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func Test_IncrementNetworkSerial(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
accountID := "testaccount"
|
||||
userID := "testuser"
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-start
|
||||
|
||||
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
|
||||
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("AddPeer failed for peer %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
startTime := time.Now()
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
t.Logf("time since start: %s", time.Since(startTime))
|
||||
|
||||
for err := range errs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, totalPeers, int(account.Network.Serial), "Expected %d serial increases in account %s, got %d", totalPeers, accountID, account.Network.Serial)
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
handler := initAccountsTestData(t, &types.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Network: types.NewNetwork(),
|
||||
Network: types.NewNetwork(accountID),
|
||||
Users: map[string]*types.User{
|
||||
adminUser.Id: adminUser,
|
||||
},
|
||||
|
||||
@@ -10,13 +10,25 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type LegacyAccountNetwork struct {
|
||||
AccountID string `gorm:"column:id"`
|
||||
Identifier string `gorm:"column:network_identifier"`
|
||||
Net net.IPNet `gorm:"column:network_net;serializer:json"`
|
||||
Dns string `gorm:"column:network_dns"`
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64 `gorm:"column:network_serial"`
|
||||
}
|
||||
|
||||
func GetColumnName(db *gorm.DB, column string) string {
|
||||
if db.Name() == "mysql" {
|
||||
return fmt.Sprintf("`%s`", column)
|
||||
@@ -39,6 +51,11 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasColumn(&model, oldColumnName) {
|
||||
log.WithContext(ctx).Debugf("Column for %T does not exist, no migration needed", oldColumnName)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
@@ -412,3 +429,149 @@ func CreateIndexIfNotExists[T any](ctx context.Context, db *gorm.DB, indexName s
|
||||
log.WithContext(ctx).Infof("successfully created index %s on table %s", indexName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName string, mapperFunc func(id string, value string) any) error {
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if !db.Migrator().HasColumn(&model, columnName) {
|
||||
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no migration needed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
jsonValue, ok := row[columnName].(string)
|
||||
if !ok || jsonValue == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var data []string
|
||||
if err := json.Unmarshal([]byte(jsonValue), &data); err != nil {
|
||||
return fmt.Errorf("unmarshal json: %w", err)
|
||||
}
|
||||
|
||||
for _, value := range data {
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
DoNothing: true, // this needs to be removed when the cleanup is enabled
|
||||
}).Create(
|
||||
mapperFunc(row["id"].(string), value),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row["id"], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Todo: Enable this after we are sure that every thing works as expected and we do not need to rollback anymore
|
||||
// if err := tx.Migrator().DropColumn(&model, columnName); err != nil {
|
||||
// return fmt.Errorf("drop column %s: %w", columnName, err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateEmbeddedToTable[T any, S any, U any](ctx context.Context, db *gorm.DB, pkey string, mapperFunc func(obj S) *U) error {
|
||||
var model T
|
||||
|
||||
log.WithContext(ctx).Debugf("Migrating embedded fields from %T to separate table", model)
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var legacyRows []S
|
||||
if err := tx.Table(tableName).Find(&legacyRows).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("Failed to read legacy accounts: %v", err)
|
||||
return fmt.Errorf("failed to read legacy accounts: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range legacyRows {
|
||||
if err := tx.Clauses(clause.OnConflict{
|
||||
DoNothing: true, // this needs to be removed when the cleanup is enabled
|
||||
}).Create(
|
||||
mapperFunc(row),
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("failed to insert id %v: %w", row, err)
|
||||
}
|
||||
}
|
||||
|
||||
// cols, err := getColumnNamesFromStruct(new(S))
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to extract column names: %w", err)
|
||||
// }
|
||||
|
||||
// for _, col := range cols {
|
||||
// if col == pkey {
|
||||
// continue
|
||||
// }
|
||||
// if err := tx.Migrator().DropColumn(&model, col); err != nil {
|
||||
// return fmt.Errorf("failed to drop column %s: %w", col, err)
|
||||
// }
|
||||
// }
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Migration of embedded fields %T from table %s into seperte table completed", new(S), tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getColumnNamesFromStruct[T any](model T) ([]string, error) {
|
||||
val := reflect.TypeOf(model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
var cols []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := val.Field(i)
|
||||
if field.Name == "ID" {
|
||||
continue // skip primary key
|
||||
}
|
||||
tag := field.Tag.Get("gorm")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
// Look for gorm:"column:..."
|
||||
for _, part := range strings.Split(tag, ";") {
|
||||
if strings.HasPrefix(part, "column:") {
|
||||
cols = append(cols, strings.TrimPrefix(part, "column:"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
@@ -360,25 +360,20 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peer groups: %w", err)
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
group.RemovePeer(peerID)
|
||||
err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save group: %w", err)
|
||||
}
|
||||
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
|
||||
return fmt.Errorf("failed to remove peer from groups: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -477,7 +472,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var updateAccountPeers bool
|
||||
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
@@ -594,12 +588,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
newPeer.DNSLabel = freeLabel
|
||||
newPeer.IP = freeIP
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
// unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
// defer func() {
|
||||
// if unlock != nil {
|
||||
// unlock()
|
||||
// }
|
||||
// }()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer)
|
||||
@@ -607,20 +601,20 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if len(groupsToAdd) > 0 {
|
||||
for _, g := range groupsToAdd {
|
||||
err = transaction.AddPeerToGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID, g)
|
||||
err = transaction.AddPeerToGroup(ctx, newPeer.ID, g)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed adding peer to All group: %w", err)
|
||||
}
|
||||
|
||||
if addedByUser {
|
||||
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
|
||||
if err != nil {
|
||||
@@ -652,14 +646,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
unlock()
|
||||
unlock = nil
|
||||
// unlock()
|
||||
// unlock = nil
|
||||
break
|
||||
}
|
||||
|
||||
if isUniqueConstraintError(err) {
|
||||
unlock()
|
||||
unlock = nil
|
||||
// unlock()
|
||||
// unlock = nil
|
||||
log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||
continue
|
||||
}
|
||||
@@ -670,11 +664,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
@@ -687,9 +676,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
if updateAccountPeers {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
}
|
||||
@@ -1019,7 +1006,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -1268,17 +1255,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
if !lock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
lock := mu.(*sync.Mutex)
|
||||
|
||||
if !lock.TryLock() {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load()))
|
||||
lock.Unlock()
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -1459,6 +1459,10 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine == "sqlite" || engine == "" {
|
||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1764,7 +1768,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
peerShouldReceiveUpdate(t, updMsg) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -2144,6 +2148,8 @@ func Test_IsUniqueConstraintError(t *testing.T) {
|
||||
|
||||
func Test_AddPeer(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
|
||||
t.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "300ms")
|
||||
t.Setenv("NB_PEER_UPDATE_BUFFER_INTERVAL", "300ms")
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -2155,7 +2161,7 @@ func Test_AddPeer(t *testing.T) {
|
||||
|
||||
_, err = createAccount(manager, accountID, userID, "domain.com")
|
||||
if err != nil {
|
||||
t.Fatal("error creating account")
|
||||
t.Fatalf("error creating account: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2165,22 +2171,21 @@ func Test_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
const totalPeers = 300 // totalPeers / differentHostnames should be less than 10 (due to concurrent retries)
|
||||
const differentHostnames = 50
|
||||
const totalPeers = 300
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, totalPeers+differentHostnames)
|
||||
errs := make(chan error, totalPeers)
|
||||
start := make(chan struct{})
|
||||
for i := 0; i < totalPeers; i++ {
|
||||
wg.Add(1)
|
||||
hostNameID := i % differentHostnames
|
||||
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(hostNameID), GoOS: "linux"},
|
||||
AccountID: accountID,
|
||||
Key: "key" + strconv.Itoa(i),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "peer" + strconv.Itoa(i), GoOS: "linux"},
|
||||
}
|
||||
|
||||
<-start
|
||||
|
||||
@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.Network{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
@@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
generateAccountSQLTypes(account)
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
}
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
|
||||
if result.Error != nil {
|
||||
@@ -247,7 +251,7 @@ func generateAccountSQLTypes(account *types.Account) {
|
||||
|
||||
for id, group := range account.Groups {
|
||||
group.ID = id
|
||||
account.GroupsG = append(account.GroupsG, *group)
|
||||
account.GroupsG = append(account.GroupsG, group)
|
||||
}
|
||||
|
||||
for id, route := range account.Routes {
|
||||
@@ -455,19 +459,40 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
for _, g := range groups {
|
||||
g.StoreGroupPeers()
|
||||
}
|
||||
return nil
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.
|
||||
Clauses(
|
||||
clause.Locking{Strength: string(lockStrength)},
|
||||
clause.OnConflict{
|
||||
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
|
||||
UpdateAll: true,
|
||||
},
|
||||
).
|
||||
Create(&groups)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save groups to store")
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
if len(g.GroupPeers) == 0 {
|
||||
if err := tx.Where("group_id = ?", g.ID).Delete(&types.GroupPeer{}).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", g.ID, err)
|
||||
return status.Errorf(status.Internal, "failed to delete group peers")
|
||||
}
|
||||
} else {
|
||||
if err := tx.Model(&g).Association("GroupPeers").Replace(g.GroupPeers); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
@@ -646,7 +671,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountIDCondition, accountID)
|
||||
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -655,6 +680,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -669,6 +698,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
likePattern := `%"ID":"` + resourceID + `"%`
|
||||
|
||||
result := tx.
|
||||
Preload(clause.Associations).
|
||||
Where("resources LIKE ?", likePattern).
|
||||
Find(&groups)
|
||||
|
||||
@@ -679,6 +709,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -738,8 +772,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
}()
|
||||
|
||||
var account types.Account
|
||||
result := s.db.Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
result := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
Preload("GroupsG.GroupPeers"). // have to be specifies as this is nester reference
|
||||
Preload(clause.Associations).
|
||||
First(&account, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
@@ -753,7 +788,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
|
||||
for i, policy := range account.Policies {
|
||||
var rules []*types.PolicyRule
|
||||
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
err := s.db.Session(&gorm.Session{Logger: logger.Default.LogMode(logger.Info)}).Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "rule not found")
|
||||
}
|
||||
@@ -784,6 +819,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
|
||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||
for _, group := range account.GroupsG {
|
||||
group.LoadGroupPeers()
|
||||
account.Groups[group.ID] = group.Copy()
|
||||
}
|
||||
account.GroupsG = nil
|
||||
@@ -998,14 +1034,14 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
accountNetwork := types.Network{}
|
||||
if err := tx.Where(accountIDCondition, accountID).First(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
}
|
||||
return accountNetwork.Network, nil
|
||||
return &accountNetwork, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
@@ -1285,55 +1321,74 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
}
|
||||
|
||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
|
||||
var group types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&group, "account_id = ? AND name = ?", accountID, "All")
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account")
|
||||
}
|
||||
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
var groupID string
|
||||
_ = s.db.Model(types.Group{}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND name = ?", accountID, "All").
|
||||
Limit(1).
|
||||
Scan(&groupID)
|
||||
|
||||
if groupID == "" {
|
||||
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerID {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(&types.GroupPeer{
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}).Error
|
||||
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
|
||||
var group types.Group
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
|
||||
First(&group)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewGroupNotFoundError(groupID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
|
||||
// AddPeerToGroup adds a peer to a group
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
peer := &types.GroupPeer{
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
for _, existingPeerID := range group.Peers {
|
||||
if existingPeerID == peerId {
|
||||
return nil
|
||||
}
|
||||
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer %s to group %s: %v", peerID, groupID, err)
|
||||
return status.Errorf(status.Internal, "failed to add peer to group")
|
||||
}
|
||||
|
||||
group.Peers = append(group.Peers, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue updating group: %s", err)
|
||||
// RemovePeerFromGroup removes a peer from a group
|
||||
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
|
||||
return status.Errorf(status.Internal, "failed to remove peer from all groups")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1401,12 +1456,19 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
var groups []*types.Group
|
||||
query := tx.
|
||||
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
|
||||
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
|
||||
Where("group_peers.peer_id = ?", peerId).
|
||||
Preload(clause.Associations).
|
||||
Find(&groups)
|
||||
|
||||
if query.Error != nil {
|
||||
return nil, query.Error
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
@@ -1455,7 +1517,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1583,7 +1645,7 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
Model(&types.Network{}).Where(accountIDCondition, accountId).Update("serial", gorm.Expr("serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
@@ -1692,7 +1754,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
var group *types.Group
|
||||
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupID)
|
||||
@@ -1701,15 +1763,14 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var group types.Group
|
||||
|
||||
@@ -1717,16 +1778,14 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
// we may need to reconsider changing the types.
|
||||
query := tx.Preload(clause.Associations)
|
||||
|
||||
switch s.storeEngine {
|
||||
case types.PostgresStoreEngine:
|
||||
query = query.Order("json_array_length(peers::json) DESC")
|
||||
case types.MysqlStoreEngine:
|
||||
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
|
||||
default:
|
||||
query = query.Order("json_array_length(peers) DESC")
|
||||
}
|
||||
|
||||
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||
result := query.
|
||||
Model(&types.Group{}).
|
||||
Joins("LEFT JOIN group_peers ON group_peers.group_id = groups.id").
|
||||
Where("groups.account_id = ? AND groups.name = ?", accountID, groupName).
|
||||
Group("groups.id").
|
||||
Order("COUNT(group_peers.peer_id) DESC").
|
||||
Limit(1).
|
||||
First(&group)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewGroupNotFoundError(groupName)
|
||||
@@ -1734,6 +1793,9 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group by name from store")
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
@@ -1745,7 +1807,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
var groups []*types.Group
|
||||
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||
@@ -1753,6 +1815,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
@@ -1761,17 +1824,36 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// SaveGroup saves a group to the store.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
|
||||
if group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group is nil")
|
||||
}
|
||||
|
||||
group = group.Copy()
|
||||
group.StoreGroupPeers()
|
||||
|
||||
if err := s.db.Omit(clause.Associations).Save(group).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to save group to store")
|
||||
}
|
||||
|
||||
if len(group.GroupPeers) == 0 {
|
||||
if err := s.db.Where("group_id = ?", group.ID).Delete(&types.GroupPeer{}).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group peers for group %s: %s", group.ID, err)
|
||||
return status.Errorf(status.Internal, "failed to delete group peers")
|
||||
}
|
||||
} else {
|
||||
if err := s.db.Model(&group).Association("GroupPeers").Replace(group.GroupPeers); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save group peers: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select(clause.Associations).
|
||||
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
|
||||
@@ -1788,6 +1870,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Select(clause.Associations).
|
||||
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||
|
||||
@@ -1340,10 +1340,12 @@ func TestSqlStore_SaveGroup(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group := &types.Group{
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
ID: "group-id",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
}
|
||||
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
require.NoError(t, err)
|
||||
@@ -1362,16 +1364,19 @@ func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
|
||||
groups := []*types.Group{
|
||||
{
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
ID: "group-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
Peers: []string{"peer1", "peer2"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
{
|
||||
ID: "group-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
Peers: []string{"peer3", "peer4"},
|
||||
Resources: []types.Resource{},
|
||||
},
|
||||
}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, accountID, groups)
|
||||
@@ -2059,7 +2064,7 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
network := types.NewNetwork(accountID)
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[nbroute.ID]*nbroute.Route)
|
||||
@@ -2506,7 +2511,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) {
|
||||
require.NoError(t, err, "failed to get group")
|
||||
require.Len(t, group.Peers, 0, "group should have 0 peers")
|
||||
|
||||
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID)
|
||||
err = store.AddPeerToGroup(context.Background(), peerID, groupID)
|
||||
require.NoError(t, err, "failed to add peer to group")
|
||||
|
||||
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
@@ -2537,7 +2542,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) {
|
||||
err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer)
|
||||
require.NoError(t, err, "failed to add peer to account")
|
||||
|
||||
err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID)
|
||||
err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID)
|
||||
require.NoError(t, err, "failed to add peer to all group")
|
||||
|
||||
group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
@@ -2623,7 +2628,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) {
|
||||
assert.Len(t, groups, 1)
|
||||
assert.Equal(t, groups[0].Name, "All")
|
||||
|
||||
err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h")
|
||||
err = store.AddPeerToGroup(context.Background(), peerID, "cfefqs706sqkneg59g4h")
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
|
||||
@@ -118,8 +118,10 @@ type Store interface {
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, peerId string, groupID string) error
|
||||
RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error
|
||||
RemovePeerFromAllGroups(ctx context.Context, peerID string) error
|
||||
GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error)
|
||||
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
|
||||
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
|
||||
@@ -351,6 +353,25 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_account_dnslabel", "account_id", "dns_label")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateJsonToTable[types.Group](ctx, db, "peers", func(id, value string) any {
|
||||
return &types.GroupPeer{
|
||||
GroupID: id,
|
||||
PeerID: value,
|
||||
}
|
||||
})
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateEmbeddedToTable[types.Account, migration.LegacyAccountNetwork, types.Network](ctx, db, "id", func(obj migration.LegacyAccountNetwork) *types.Network {
|
||||
return &types.Network{
|
||||
AccountID: obj.AccountID,
|
||||
Identifier: obj.Identifier,
|
||||
Net: obj.Net,
|
||||
Serial: obj.Serial,
|
||||
Dns: obj.Dns,
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,13 +67,13 @@ type Account struct {
|
||||
IsDomainPrimaryAccount bool
|
||||
SetupKeys map[string]*SetupKey `gorm:"-"`
|
||||
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
Network *Network `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Peers map[string]*nbpeer.Peer `gorm:"-"`
|
||||
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Users map[string]*User `gorm:"-"`
|
||||
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Groups map[string]*Group `gorm:"-"`
|
||||
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||
Routes map[route.ID]*route.Route `gorm:"-"`
|
||||
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
|
||||
@@ -26,7 +26,8 @@ type Group struct {
|
||||
Issued string
|
||||
|
||||
// Peers list of the group
|
||||
Peers []string `gorm:"serializer:json"`
|
||||
Peers []string `gorm:"-"`
|
||||
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// Resources contains a list of resources in that group
|
||||
Resources []Resource `gorm:"serializer:json"`
|
||||
@@ -34,6 +35,29 @@ type Group struct {
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
}
|
||||
|
||||
type GroupPeer struct {
|
||||
GroupID string `gorm:"primaryKey"`
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
g.Peers[i] = peer.PeerID
|
||||
}
|
||||
g.GroupPeers = []GroupPeer{}
|
||||
}
|
||||
func (g *Group) StoreGroupPeers() {
|
||||
g.GroupPeers = make([]GroupPeer, len(g.Peers))
|
||||
for i, peer := range g.Peers {
|
||||
g.GroupPeers[i] = GroupPeer{
|
||||
GroupID: g.ID,
|
||||
PeerID: peer,
|
||||
}
|
||||
}
|
||||
g.Peers = []string{}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the group
|
||||
func (g *Group) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
@@ -46,13 +70,16 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
|
||||
func (g *Group) Copy() *Group {
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
||||
Resources: make([]Resource, len(g.Resources)),
|
||||
IntegrationReference: g.IntegrationReference,
|
||||
}
|
||||
copy(group.Peers, g.Peers)
|
||||
copy(group.GroupPeers, g.GroupPeers)
|
||||
copy(group.Resources, g.Resources)
|
||||
return group
|
||||
}
|
||||
|
||||
@@ -107,7 +107,8 @@ func ipToBytes(ip net.IP) []byte {
|
||||
}
|
||||
|
||||
type Network struct {
|
||||
Identifier string `json:"id"`
|
||||
AccountID string `gorm:"primaryKey"`
|
||||
Identifier string `gorm:"index"`
|
||||
Net net.IPNet `gorm:"serializer:json"`
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
@@ -117,9 +118,13 @@ type Network struct {
|
||||
Mu sync.Mutex `json:"-" gorm:"-"`
|
||||
}
|
||||
|
||||
func (*Network) TableName() string {
|
||||
return "account_networks"
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
|
||||
func NewNetwork() *Network {
|
||||
func NewNetwork(accountID string) *Network {
|
||||
|
||||
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
|
||||
sub, _ := n.Subnet(SubnetSize)
|
||||
@@ -129,6 +134,7 @@ func NewNetwork() *Network {
|
||||
intn := r.Intn(len(sub))
|
||||
|
||||
return &Network{
|
||||
AccountID: accountID,
|
||||
Identifier: xid.New().String(),
|
||||
Net: sub[intn].IPNet,
|
||||
Dns: "",
|
||||
@@ -151,6 +157,7 @@ func (n *Network) CurrentSerial() uint64 {
|
||||
|
||||
func (n *Network) Copy() *Network {
|
||||
return &Network{
|
||||
AccountID: n.AccountID,
|
||||
Identifier: n.Identifier,
|
||||
Net: n.Net,
|
||||
Dns: n.Dns,
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNewNetwork(t *testing.T) {
|
||||
network := NewNetwork()
|
||||
network := NewNetwork("accountID")
|
||||
|
||||
// generated net should be a subnet of a larger 100.64.0.0/10 net
|
||||
ipNet := net.IPNet{IP: net.ParseIP("100.64.0.0"), Mask: net.IPMask{255, 192, 0, 0}}
|
||||
|
||||
@@ -35,7 +35,7 @@ type SetupKey struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
Key string
|
||||
KeySecret string
|
||||
KeySecret string `gorm:"index"`
|
||||
Name string
|
||||
Type SetupKeyType
|
||||
CreatedAt time.Time
|
||||
|
||||
Reference in New Issue
Block a user