[management] Remove save account calls (#4349)

This commit is contained in:
Pascal Fischer
2025-08-18 12:37:20 +02:00
committed by GitHub
parent 7cd5dcae59
commit 6a3846a8b7
8 changed files with 92 additions and 34 deletions

View File

@@ -1952,20 +1952,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain") return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
} }
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
var account *types.Account
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
account, err = transaction.GetAccount(ctx, accountId) ok, domain, err := transaction.IsPrimaryAccount(ctx, accountId)
if err != nil { if err != nil {
return err return err
} }
if account.IsDomainPrimaryAccount { if ok {
return nil return nil
} }
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain) existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
// error is not a not found error // error is not a not found error
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
@@ -1981,9 +1980,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return status.Errorf(status.Internal, "cannot update account to primary") return status.Errorf(status.Internal, "cannot update account to primary")
} }
account.IsDomainPrimaryAccount = true if err := transaction.MarkAccountPrimary(ctx, accountId); err != nil {
if err := transaction.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{ log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId, "accountId": accountId,
}).Errorf("failed to update account to primary: %v", err) }).Errorf("failed to update account to primary: %v", err)
@@ -1993,10 +1990,10 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return err
} }
return account, nil return nil
} }
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
@@ -2067,14 +2064,12 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()), Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
} }
account, err := transaction.GetAccount(ctx, accountID) err := transaction.UpdateAccountNetwork(ctx, accountID, newIPNet)
if err != nil { if err != nil {
return err return err
} }
account.Network.Net = newIPNet peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "")
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -2094,10 +2089,6 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
takenIPs = append(takenIPs, newIP) takenIPs = append(takenIPs, newIP)
} }
if err = transaction.SaveAccount(ctx, account); err != nil {
return err
}
for _, peer := range peers { for _, peer := range peers {
if err = transaction.SavePeer(ctx, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)

View File

@@ -7,7 +7,6 @@ import (
"time" "time"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -18,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
type ExternalCacheManager nbcache.UserDataCache type ExternalCacheManager nbcache.UserDataCache
@@ -120,7 +120,7 @@ type Manager interface {
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store GetStore() store.Store
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
} }

View File

@@ -3250,11 +3250,13 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
assert.Equal(t, 0, len(account2.Users)) assert.Equal(t, 0, len(account2.Users))
assert.Equal(t, 0, len(account2.SetupKeys)) assert.Equal(t, 0, len(account2.SetupKeys))
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
account, err = manager.Store.GetAccount(ctx, account.Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount) assert.True(t, account.IsDomainPrimaryAccount)
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id) err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
assert.Error(t, err, "should not be able to update a second account to primary") assert.Error(t, err, "should not be able to update a second account to primary")
} }
@@ -3275,7 +3277,9 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
assert.False(t, account.IsDomainPrimaryAccount) assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain) assert.Equal(t, domain, account.Domain)
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
account, err = manager.Store.GetAccount(ctx, account.Id)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount) assert.True(t, account.IsDomainPrimaryAccount)

View File

@@ -50,23 +50,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
defer unlock() defer unlock()
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
a, err := transaction.GetAccount(ctx, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return err return err
} }
var extra *types.ExtraSettings var extra *types.ExtraSettings
if a.Settings.Extra != nil { if settings.Extra != nil {
extra = a.Settings.Extra extra = settings.Extra
} else { } else {
extra = &types.ExtraSettings{} extra = &types.ExtraSettings{}
a.Settings.Extra = extra settings.Extra = extra
} }
extra.IntegratedValidator = validator extra.IntegratedValidator = validator
extra.IntegratedValidatorGroups = groups extra.IntegratedValidatorGroups = groups
return transaction.SaveAccount(ctx, a) return transaction.SaveAccountSettings(ctx, accountID, settings)
}) })
} }

View File

@@ -10,7 +10,6 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -21,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
var _ account.Manager = (*MockAccountManager)(nil) var _ account.Manager = (*MockAccountManager)(nil)
@@ -114,7 +114,7 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store GetStoreFunc func() store.Store
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
@@ -933,11 +933,11 @@ func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Cont
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented") return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
} }
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
if am.UpdateToPrimaryAccountFunc != nil { if am.UpdateToPrimaryAccountFunc != nil {
return am.UpdateToPrimaryAccountFunc(ctx, accountId) return am.UpdateToPrimaryAccountFunc(ctx, accountId)
} }
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented") return status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
} }
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) { func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {

View File

@@ -2832,3 +2832,57 @@ func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFu
}() }()
return ctx, cancel return ctx, cancel
} }
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
var info types.PrimaryAccountInfo
result := s.db.Model(&types.Account{}).
Select("is_domain_primary_account, domain").
Where(idQueryCondition, accountID).
Take(&info)
if result.Error != nil {
return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error)
}
return info.IsDomainPrimaryAccount, info.Domain, nil
}
func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
result := s.db.Model(&types.Account{}).
Where(idQueryCondition, accountID).
Update("is_domain_primary_account", true)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark account as primary")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}
type accountNetworkPatch struct {
Network *types.Network `gorm:"embedded;embeddedPrefix:network_"`
}
func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error {
patch := accountNetworkPatch{
Network: &types.Network{Net: ipNet},
}
result := s.db.WithContext(ctx).
Model(&types.Account{}).
Where(idQueryCondition, accountID).
Updates(&patch)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error)
return status.Errorf(status.Internal, "failed to update account network")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}

View File

@@ -202,6 +202,9 @@ type Store interface {
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
} }
const ( const (

View File

@@ -16,16 +16,16 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -89,6 +89,12 @@ type Account struct {
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
} }
// this class is used by gorm only
type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool
Domain string
}
// Subclass used in gorm to only load network and not whole account // Subclass used in gorm to only load network and not whole account
type AccountNetwork struct { type AccountNetwork struct {
Network *Network `gorm:"embedded;embeddedPrefix:network_"` Network *Network `gorm:"embedded;embeddedPrefix:network_"`