Compare commits

...

19 Commits

Author SHA1 Message Date
pascal
a88dd8a692 fix benchmark 2026-01-09 16:09:42 +01:00
pascal
2b00a429d7 fix comment 2026-01-09 14:20:22 +01:00
pascal
29a31001bd properly copy group users 2026-01-09 13:55:17 +01:00
pascal
023d85f42a Merge branch 'main' into feature/migrate-auto-groups-to-table
# Conflicts:
#	management/server/migration/migration.go
#	management/server/store/store.go
2026-01-09 13:46:30 +01:00
pascal
3434760526 update service user tests 2026-01-09 02:10:13 +01:00
pascal
e33e5673c5 do not use user group add 2026-01-09 01:55:18 +01:00
pascal
71d98940dc fix group test setup 2026-01-09 01:42:29 +01:00
pascal
5f7a6b839b fix group test setup 2026-01-09 01:14:03 +01:00
pascal
1481dbcdd7 fix group test setup 2026-01-09 00:53:03 +01:00
pascal
7956f676a4 fix group test setup 2026-01-09 00:49:21 +01:00
pascal
ddcf9f820b fix StoreAutoGroups len alloc 2026-01-09 00:30:55 +01:00
pascal
475ce092c8 first dave user, then associations 2026-01-09 00:28:28 +01:00
pascal
80c49c268f add debugging store call 2026-01-08 23:53:54 +01:00
pascal
cdfe0f3d41 fix get account equal 2026-01-08 23:49:12 +01:00
pascal
794976263e fix account copy 2026-01-08 23:05:43 +01:00
pascal
77ea4b7444 fix saveAccount and gorm order 2026-01-08 22:47:32 +01:00
pascal
9fd34718a6 remove emptying users 2026-01-08 22:21:46 +01:00
pascal
ea37d4b768 fix tests 2026-01-08 20:32:24 +01:00
pascal
f7ee019f26 migrate auto groups to different table 2026-01-08 15:45:48 +01:00
13 changed files with 554 additions and 92 deletions

View File

@@ -1402,9 +1402,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return fmt.Errorf("error saving groups: %w", err)
}
addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, user); err != nil {
return fmt.Errorf("error saving user: %w", err)

View File

@@ -918,6 +918,7 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
}
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Setenv("NETBIRD_STORE_ENGINE", "postgres")
claims := auth.UserAuth{
Domain: "example.com",
UserId: "pvt-domain-user",
@@ -945,6 +946,18 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
b.Fatal(err)
}
a, err := am.Store.GetAccount(context.Background(), id)
if err != nil {
b.Fatal(err)
}
a.Groups = genGroups()
err = am.Store.SaveAccount(context.Background(), a)
if err != nil {
b.Fatal(err)
}
users := genUsers("priv", 100)
acc, err := am.Store.GetAccount(context.Background(), id)
@@ -1005,6 +1018,41 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
}
func genGroups() map[string]*types.Group {
return map[string]*types.Group{
"one": {
Name: "one",
},
"two": {
Name: "two",
},
"three": {
Name: "three",
},
"four": {
Name: "four",
},
"five": {
Name: "five",
},
"six": {
Name: "six",
},
"seven": {
Name: "seven",
},
"eight": {
Name: "eight",
},
"nine": {
Name: "nine",
},
"ten": {
Name: "ten",
},
}
}
func genUsers(p string, n int) map[string]*types.User {
users := map[string]*types.User{}
now := time.Now()
@@ -1723,6 +1771,13 @@ func TestAccount_Copy(t *testing.T) {
Id: "user1",
Role: types.UserRoleAdmin,
AutoGroups: []string{"group1"},
Groups: []*types.GroupUser{
{
AccountID: "account1",
UserID: "user1",
GroupID: "group1",
},
},
PATs: map[string]*types.PersonalAccessToken{
"pat1": {
ID: "pat1",
@@ -1742,6 +1797,13 @@ func TestAccount_Copy(t *testing.T) {
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
GroupUsers: []types.GroupUser{
{
AccountID: "account1",
UserID: "user1",
GroupID: "group1",
},
},
},
},
Policies: []*types.Policy{

View File

@@ -380,13 +380,6 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, nil, err
@@ -400,6 +393,23 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
account, err = am.Store.GetAccount(context.Background(), accountID)
if err != nil {
return nil, nil, err
}
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, nil, err
}
acc, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, nil, err

View File

@@ -539,3 +539,103 @@ func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
return nil
}
// CleanupOrphanedIDs removes non-existent IDs from the JSON array column.
// T is the type of the model that contains the list.
// This migration cleans up the lists field by removing IDs that no longer exist in the target table.
func CleanupOrphanedIDs[T, S any](ctx context.Context, db *gorm.DB, columnName string) error {
var sourceModel T
var fkModel S
if !db.Migrator().HasTable(&sourceModel) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", sourceModel)
return nil
}
if !db.Migrator().HasTable(&fkModel) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", fkModel)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&sourceModel)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if !db.Migrator().HasColumn(&sourceModel, columnName) {
log.WithContext(ctx).Debugf("Column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
// Get all valid IDs from the fk table
var validIDs []string
if err := tx.Model(fkModel).Select("id").Pluck("id", &validIDs).Error; err != nil {
return fmt.Errorf("fetch valid group IDs: %w", err)
}
validIDMap := make(map[string]bool, len(validIDs))
for _, id := range validIDs {
validIDMap[id] = true
}
updatedCount := 0
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" || jsonValue == "null" {
continue
}
var list []string
if err := json.Unmarshal([]byte(jsonValue), &list); err != nil {
log.WithContext(ctx).Warnf("Failed to unmarshal %s for id %v: %v", columnName, row["id"], err)
continue
}
if len(list) == 0 {
continue
}
// Filter out non-existent IDs
cleanedList := make([]string, 0, len(list))
for _, groupID := range list {
if validIDMap[groupID] {
cleanedList = append(cleanedList, groupID)
}
}
// Only update if there were orphaned ids removed
if len(cleanedList) != len(list) {
cleanedJSON, err := json.Marshal(cleanedList)
if err != nil {
return fmt.Errorf("marshal cleaned %s: %w", columnName, err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, cleanedJSON).Error; err != nil {
return fmt.Errorf("update row with id %v: %w", row["id"], err)
}
updatedCount++
}
}
if updatedCount > 0 {
log.WithContext(ctx).Infof("Cleaned up orphaned %s in %d rows from table %s", columnName, updatedCount, tableName)
} else {
log.WithContext(ctx).Debugf("No orphaned %s found in table %s", columnName, tableName)
}
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Cleanup of orphaned %s from table %s completed", columnName, tableName)
return nil
}

View File

@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -177,7 +177,8 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
// Encrypt sensitive user data before saving
for i := range account.UsersG {
for i, user := range account.UsersG {
user.StoreAutoGroups()
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
@@ -203,15 +204,35 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
return result.Error
}
// Save account without UsersG.Groups to avoid FK constraint violations
// (groups must exist before group_users can reference them)
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Omit("UsersG.Groups").
Clauses(clause.OnConflict{UpdateAll: true}).
Create(account)
if result.Error != nil {
return result.Error
}
// Now save the user-group associations after both users and groups exist
for _, user := range account.UsersG {
if len(user.Groups) > 0 {
result = tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
UpdateAll: true,
}).Create(&user.Groups)
if result.Error != nil {
return result.Error
}
}
}
return nil
})
if err != nil {
return err
}
took := time.Since(start)
if s.metrics != nil {
@@ -219,7 +240,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
}
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
return err
return nil
}
// generateAccountSQLTypes generates the GORM compatible types for the account
@@ -243,7 +264,7 @@ func generateAccountSQLTypes(account *types.Account) {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
account.UsersG = append(account.UsersG, user)
}
for id, group := range account.Groups {
@@ -453,6 +474,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
@@ -472,16 +494,37 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
result := s.db.Save(userCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user to store")
err := s.transaction(func(tx *gorm.DB) error {
result := tx.Omit("Groups").Save(userCopy)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
result = tx.Delete(&types.GroupUser{}, "user_id = ?", user.Id)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to delete user groups from store: %v", result.Error)
}
if len(userCopy.Groups) != 0 {
result = tx.Save(userCopy.Groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user groups to store: %v", result.Error)
}
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to save user to store: %s", err)
return err
}
return nil
}
@@ -617,6 +660,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
var user types.User
result := tx.
Preload("Groups").
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).Take(&user)
if result.Error != nil {
@@ -631,6 +675,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -641,7 +687,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, idQueryCondition, userID)
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -653,6 +699,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -680,7 +728,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
}
var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID)
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -693,6 +741,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
}
return users, nil
@@ -705,7 +754,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
@@ -717,6 +766,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -867,7 +918,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("UsersG.Groups").
Preload("GroupsG").
Preload("GroupsG.GroupPeers").
Preload("GroupsG.GroupUsers").
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
@@ -908,13 +962,14 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.AutoGroups = []string{}
if user.Groups == nil {
user.Groups = []*types.GroupUser{}
}
user.LoadAutoGroups()
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
account.Users[user.Id] = &user
account.Users[user.Id] = user
user.PATsG = nil
}
account.UsersG = nil
@@ -1116,8 +1171,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
wg.Add(4)
errChan = make(chan error, 4)
var pats []types.PersonalAccessToken
go func() {
@@ -1149,6 +1204,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
}()
var groupUsers []types.GroupUser
go func() {
defer wg.Done()
var err error
groupUsers, err = s.getGroupUsers(ctx, userIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
@@ -1174,6 +1239,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
groupsByUserID := make(map[string][]*types.GroupUser)
for i := range groupUsers {
gu := &groupUsers[i]
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
@@ -1188,7 +1259,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user := account.UsersG[i]
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
@@ -1199,6 +1270,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
user.PATs[pat.ID] = pat
}
}
user.Groups = groupsByUserID[user.Id]
user.LoadAutoGroups()
account.Users[user.Id] = user
}
@@ -1595,44 +1668,41 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
return peers, nil
}
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]*types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.User, error) {
var u types.User
var autoGroups []byte
var lastLogin, createdAt sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err == nil {
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
if autoGroups != nil {
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
} else {
u.AutoGroups = []string{}
}
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err != nil {
return &u, err
}
return u, err
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
return &u, nil
})
if err != nil {
return nil, err
@@ -2038,6 +2108,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
return groupPeers, nil
}
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
if len(userIDs) == 0 {
return nil, nil
}
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, userIDs)
if err != nil {
return nil, err
}
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
if err != nil {
return nil, err
}
return groupUsers, nil
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
@@ -2659,6 +2745,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
return nil
}
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
user := &types.GroupUser{
AccountID: accountID,
GroupID: groupID,
UserID: userID,
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
DoNothing: true,
}).Create(user).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
return status.Errorf(status.Internal, "failed to add user to group")
}
return nil
}
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
return status.Errorf(status.Internal, "failed to remove user from group")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.

View File

@@ -1372,6 +1372,7 @@ func TestSqlStore_CreateGroup(t *testing.T) {
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
GroupUsers: []types.GroupUser{},
}
err = store.CreateGroup(context.Background(), group)
require.NoError(t, err)
@@ -1396,6 +1397,7 @@ func TestSqlStore_CreateUpdateGroups(t *testing.T) {
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
GroupUsers: []types.GroupUser{},
},
{
ID: "group-2",
@@ -1404,6 +1406,7 @@ func TestSqlStore_CreateUpdateGroups(t *testing.T) {
Peers: []string{},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
GroupUsers: []types.GroupUser{},
},
}
err = store.CreateGroups(context.Background(), accountID, groups)
@@ -3059,7 +3062,7 @@ func TestSqlStore_SaveUser(t *testing.T) {
AccountID: accountID,
Role: types.UserRoleAdmin,
IsServiceUser: false,
AutoGroups: []string{"groupA", "groupB"},
AutoGroups: []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g3g"},
Blocked: false,
LastLogin: util.ToPtr(time.Now().UTC()),
CreatedAt: time.Now().UTC().Add(-time.Hour),
@@ -3097,13 +3100,13 @@ func TestSqlStore_SaveUsers(t *testing.T) {
Id: "user-1",
AccountID: accountID,
Issued: "api",
AutoGroups: []string{"groupA", "groupB"},
AutoGroups: []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g3g"},
},
{
Id: "user-2",
AccountID: accountID,
Issued: "integration",
AutoGroups: []string{"groupA"},
AutoGroups: []string{"cfefqs706sqkneg59g2g"},
},
}
err = store.SaveUsers(context.Background(), users)
@@ -3113,7 +3116,7 @@ func TestSqlStore_SaveUsers(t *testing.T) {
require.NoError(t, err)
require.Len(t, accountUsers, 4)
users[1].AutoGroups = []string{"groupA", "groupC"}
users[1].AutoGroups = []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g4g"}
err = store.SaveUsers(context.Background(), users)
require.NoError(t, err)
@@ -3151,7 +3154,7 @@ func TestSqlStore_SaveUserWithEncryption(t *testing.T) {
Role: types.UserRoleUser,
Email: "",
Name: "",
AutoGroups: []string{"groupA"},
AutoGroups: []string{"cfefqs706sqkneg59g2g"},
}
err = store.SaveUser(context.Background(), user)
require.NoError(t, err)
@@ -3180,7 +3183,7 @@ func TestSqlStore_SaveUserWithEncryption(t *testing.T) {
Role: types.UserRoleAdmin,
Email: "test@example.com",
Name: "Test User",
AutoGroups: []string{"groupB"},
AutoGroups: []string{"cfefqs706sqkneg59g3g"},
}
err = store.SaveUser(context.Background(), user)
require.NoError(t, err)

View File

@@ -82,6 +82,7 @@ func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
user.LoadAutoGroups()
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
@@ -89,6 +90,9 @@ func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
if len(group.GroupUsers) == 0 {
account.Groups[group.ID] = nil
}
}
account.GroupsG = nil
@@ -175,10 +179,12 @@ func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*ty
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.LoadAutoGroups()
if len(user.AutoGroups) == 0 {
user.AutoGroups = []string{}
user.Groups = []*types.GroupUser{}
}
account.Users[user.Id] = &user
account.Users[user.Id] = user
user.PATsG = nil
}
account.UsersG = nil
@@ -191,6 +197,9 @@ func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*ty
if group.Resources == nil {
group.Resources = []types.Resource{}
}
if group.GroupUsers == nil {
group.GroupUsers = []types.GroupUser{}
}
account.Groups[group.ID] = group
}
account.GroupsG = nil
@@ -259,7 +268,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
models := []interface{}{
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
@@ -609,10 +618,12 @@ func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
for key, oldVal := range expected.Groups {
newVal, ok := actual.Groups[key]
if oldVal != nil && newVal != nil {
sort.Strings(oldVal.Peers)
sort.Strings(newVal.Peers)
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
}
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
sort.Strings(oldVal.Peers)
sort.Strings(newVal.Peers)
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
}
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
@@ -900,7 +911,7 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user := account.UsersG[i]
user.PATs = make(map[string]*types.PersonalAccessToken)
if userPats, ok := patsByUserID[user.Id]; ok {
for j := range userPats {

View File

@@ -89,6 +89,8 @@ type Store interface {
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error
RemoveUserFromGroup(ctx context.Context, userID, groupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
@@ -353,6 +355,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.RemoveDuplicatePeerKeys(ctx, db)
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups")
},
}
}
@@ -392,6 +397,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any {
return &types.GroupUser{
AccountID: accountID,
GroupID: value,
UserID: id,
}
})
},
}
}

View File

@@ -87,7 +87,7 @@ type Account struct {
Peers map[string]*nbpeer.Peer `gorm:"-"`
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
UsersG []*User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`

View File

@@ -28,6 +28,7 @@ type Group struct {
// Peers list of the group
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
// Resources contains a list of resources in that group
Resources []Resource `gorm:"serializer:json"`
@@ -41,6 +42,20 @@ type GroupPeer struct {
PeerID string `gorm:"primaryKey"`
}
type GroupUser struct {
AccountID string `gorm:"index"`
GroupID string `gorm:"primaryKey"`
UserID string `gorm:"primaryKey"`
}
func (g *GroupUser) Copy() *GroupUser {
return &GroupUser{
AccountID: g.AccountID,
GroupID: g.GroupID,
UserID: g.UserID,
}
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
@@ -78,11 +93,13 @@ func (g *Group) Copy() *Group {
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
GroupUsers: make([]GroupUser, len(g.GroupUsers)),
Resources: make([]Resource, len(g.Resources)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
copy(group.GroupPeers, g.GroupPeers)
copy(group.GroupUsers, g.GroupUsers)
copy(group.Resources, g.Resources)
return group
}

View File

@@ -85,9 +85,11 @@ type User struct {
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
AutoGroups []string `gorm:"-"`
// GroupUsers replaces old AutoGroups
Groups []*GroupUser `gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// PendingApproval indicates whether the user requires approval before being activated
@@ -106,6 +108,26 @@ type User struct {
Email string `gorm:"default:''"`
}
func (u *User) LoadAutoGroups() {
u.AutoGroups = make([]string, 0, len(u.Groups))
for _, group := range u.Groups {
u.AutoGroups = append(u.AutoGroups, group.GroupID)
}
u.Groups = []*GroupUser{}
}
func (u *User) StoreAutoGroups() {
u.Groups = make([]*GroupUser, 0, len(u.AutoGroups))
for _, groupID := range u.AutoGroups {
u.Groups = append(u.Groups, &GroupUser{
AccountID: u.AccountID,
GroupID: groupID,
UserID: u.Id,
})
}
u.AutoGroups = []string{}
}
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
@@ -198,8 +220,20 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
// Copy the user
func (u *User) Copy() *User {
autoGroups := make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
var groupUsers []*GroupUser
if u.Groups != nil {
groupUsers = make([]*GroupUser, len(u.Groups))
for i, groupUser := range u.Groups {
groupUsers[i] = groupUser.Copy()
}
}
var autoGroups []string
if u.AutoGroups != nil {
autoGroups = make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
}
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs {
pats[k] = v.Copy()
@@ -221,6 +255,7 @@ func (u *User) Copy() *User {
IntegrationReference: u.IntegrationReference,
Email: u.Email,
Name: u.Name,
Groups: groupUsers,
}
}

View File

@@ -45,7 +45,21 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
newUser.AccountID = accountID
log.WithContext(ctx).Debugf("New User: %v", newUser)
if err = am.Store.SaveUser(ctx, newUser); err != nil {
if err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
err = tx.SaveUser(ctx, newUser)
if err != nil {
return err
}
for _, groupID := range autoGroups {
err = tx.AddUserToGroup(ctx, accountID, newUserID, groupID)
if err != nil {
return fmt.Errorf("failed to add user to group %s: %w", groupID, err)
}
}
return nil
}); err != nil {
return nil, err
}
@@ -119,7 +133,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
Id: idpUser.ID,
AccountID: accountID,
Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups,
Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference,
CreatedAt: time.Now().UTC(),
@@ -127,6 +140,23 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
Name: invite.Name,
}
err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
err = tx.SaveUser(ctx, newUser)
if err != nil {
return err
}
for _, group := range invite.AutoGroups {
err = tx.AddUserToGroup(ctx, accountID, userID, group)
if err != nil {
return fmt.Errorf("failed to add user to group %s: %w", group, err)
}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to save user: %w", err)
}
if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err
}
@@ -715,6 +745,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
updatedUser.Role = update.Role
updatedUser.Blocked = update.Blocked
updatedUser.AutoGroups = update.AutoGroups
updatedUser.StoreAutoGroups()
// these two fields can't be set via API, only via direct call to the method
updatedUser.Issued = update.Issued
updatedUser.IntegrationReference = update.IntegrationReference
@@ -737,28 +768,58 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
peersToExpire = userPeers
}
var removedGroups, addedGroups []string
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
updateAccountPeers, removedGroupsIDs, addedGroupsIDs, err := am.processUserGroupsUpdate(ctx, transaction, oldUser, updatedUser, userPeers, settings)
if err != nil {
return false, nil, nil, nil, err
}
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
func (am *DefaultAccountManager) processUserGroupsUpdate(ctx context.Context, transaction store.Store, oldUser *types.User, updatedUser *types.User, userPeers []*nbpeer.Peer, settings *types.Settings) (bool, []string, []string, error) {
removedGroups := util.Difference(oldUser.AutoGroups, updatedUser.AutoGroups)
addedGroups := util.Difference(updatedUser.AutoGroups, oldUser.AutoGroups)
updateAccountPeers := len(userPeers) > 0
removedGroupsIDs := make([]string, 0, len(removedGroups))
for _, id := range removedGroups {
err := transaction.RemoveUserFromGroup(ctx, updatedUser.Id, id)
if err != nil {
return false, nil, nil, fmt.Errorf("failed to remove user %s from group %s: %w", updatedUser.Id, id, err)
}
updateAccountPeers = true
removedGroupsIDs = append(removedGroupsIDs, id)
}
addedGroupsIDs := make([]string, 0, len(addedGroups))
for _, id := range addedGroups {
err := transaction.AddUserToGroup(ctx, updatedUser.AccountID, updatedUser.Id, id)
if err != nil {
return false, nil, nil, fmt.Errorf("failed to add user %s to group %s: %w", updatedUser.Id, id, err)
}
updateAccountPeers = true
addedGroupsIDs = append(addedGroupsIDs, id)
}
if updatedUser.Groups != nil && settings.GroupsPropagationEnabled {
for _, peer := range userPeers {
for _, groupID := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
for _, id := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, id); err != nil {
return false, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, id, err)
}
}
for _, groupID := range addedGroups {
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
for _, id := range addedGroups {
if err := transaction.AddPeerToGroup(ctx, updatedUser.AccountID, peer.ID, id); err != nil {
return false, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, id, err)
}
}
}
}
updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
return updateAccountPeers, removedGroupsIDs, addedGroupsIDs, nil
}
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.

View File

@@ -345,6 +345,9 @@ func TestUser_Copy(t *testing.T) {
IsServiceUser: true,
ServiceUserName: "servicename",
AutoGroups: []string{"group1", "group2"},
Groups: []*types.GroupUser{
{AccountID: "accountId", GroupID: "groupId", UserID: "userId"},
},
PATs: map[string]*types.PersonalAccessToken{
"pat1": {
ID: "pat1",
@@ -413,6 +416,14 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Groups["group1"] = &types.Group{
ID: "group1",
Name: "group1",
}
account.Groups["group2"] = &types.Group{
ID: "group2",
Name: "group2",
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -460,6 +471,14 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Groups["group1"] = &types.Group{
ID: "group1",
Name: "group1",
}
account.Groups["group2"] = &types.Group{
ID: "group2",
Name: "group2",
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -539,6 +558,14 @@ func TestUser_InviteNewUser(t *testing.T) {
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Groups["group1"] = &types.Group{
ID: "group1",
Name: "group1",
}
account.Groups["group2"] = &types.Group{
ID: "group2",
Name: "group2",
}
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -1653,6 +1680,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
AutoGroups: []string{},
},
Permissions: mergeRolePermissions(roles.User),
},
@@ -1672,6 +1700,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
AutoGroups: []string{},
},
Permissions: mergeRolePermissions(roles.Admin),
},
@@ -1691,6 +1720,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
AutoGroups: []string{},
},
Permissions: mergeRolePermissions(roles.User),
Restricted: true,
@@ -1712,6 +1742,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
LastLogin: time.Time{},
Issued: "api",
IntegrationReference: integration_reference.IntegrationReference{},
AutoGroups: []string{},
},
Permissions: mergeRolePermissions(roles.User),
Restricted: false,