mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 09:03:54 -04:00
[management] Remove save account calls (#4349)
This commit is contained in:
@@ -1952,20 +1952,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
var account *types.Account
|
||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
account, err = transaction.GetAccount(ctx, accountId)
|
||||
ok, domain, err := transaction.IsPrimaryAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if account.IsDomainPrimaryAccount {
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
|
||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||
|
||||
// error is not a not found error
|
||||
if handleNotFound(err) != nil {
|
||||
@@ -1981,9 +1980,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
return status.Errorf(status.Internal, "cannot update account to primary")
|
||||
}
|
||||
|
||||
account.IsDomainPrimaryAccount = true
|
||||
|
||||
if err := transaction.SaveAccount(ctx, account); err != nil {
|
||||
if err := transaction.MarkAccountPrimary(ctx, accountId); err != nil {
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountId": accountId,
|
||||
}).Errorf("failed to update account to primary: %v", err)
|
||||
@@ -1993,10 +1990,10 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||
@@ -2067,14 +2064,12 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
|
||||
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
|
||||
}
|
||||
|
||||
account, err := transaction.GetAccount(ctx, accountID)
|
||||
err := transaction.UpdateAccountNetwork(ctx, accountID, newIPNet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Network.Net = newIPNet
|
||||
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -2094,10 +2089,6 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
|
||||
takenIPs = append(takenIPs, newIP)
|
||||
}
|
||||
|
||||
if err = transaction.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type ExternalCacheManager nbcache.UserDataCache
|
||||
@@ -120,7 +120,7 @@ type Manager interface {
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
GetStore() store.Store
|
||||
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
}
|
||||
|
||||
@@ -3250,11 +3250,13 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
assert.Equal(t, 0, len(account2.Users))
|
||||
assert.Equal(t, 0, len(account2.SetupKeys))
|
||||
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
account, err = manager.Store.GetAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||
err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||
assert.Error(t, err, "should not be able to update a second account to primary")
|
||||
}
|
||||
|
||||
@@ -3275,7 +3277,9 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
account, err = manager.Store.GetAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
|
||||
@@ -50,23 +50,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
|
||||
defer unlock()
|
||||
|
||||
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
a, err := transaction.GetAccount(ctx, accountID)
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var extra *types.ExtraSettings
|
||||
|
||||
if a.Settings.Extra != nil {
|
||||
extra = a.Settings.Extra
|
||||
if settings.Extra != nil {
|
||||
extra = settings.Extra
|
||||
} else {
|
||||
extra = &types.ExtraSettings{}
|
||||
a.Settings.Extra = extra
|
||||
settings.Extra = extra
|
||||
}
|
||||
|
||||
extra.IntegratedValidator = validator
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return transaction.SaveAccount(ctx, a)
|
||||
return transaction.SaveAccountSettings(ctx, accountID, settings)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -21,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var _ account.Manager = (*MockAccountManager)(nil)
|
||||
@@ -114,7 +114,7 @@ type MockAccountManager struct {
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
GetStoreFunc func() store.Store
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
|
||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
@@ -933,11 +933,11 @@ func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Cont
|
||||
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
|
||||
if am.UpdateToPrimaryAccountFunc != nil {
|
||||
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {
|
||||
|
||||
@@ -2832,3 +2832,57 @@ func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFu
|
||||
}()
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||
var info types.PrimaryAccountInfo
|
||||
result := s.db.Model(&types.Account{}).
|
||||
Select("is_domain_primary_account, domain").
|
||||
Where(idQueryCondition, accountID).
|
||||
Take(&info)
|
||||
|
||||
if result.Error != nil {
|
||||
return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error)
|
||||
}
|
||||
|
||||
return info.IsDomainPrimaryAccount, info.Domain, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
|
||||
result := s.db.Model(&types.Account{}).
|
||||
Where(idQueryCondition, accountID).
|
||||
Update("is_domain_primary_account", true)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark account as primary")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type accountNetworkPatch struct {
|
||||
Network *types.Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error {
|
||||
patch := accountNetworkPatch{
|
||||
Network: &types.Network{Net: ipNet},
|
||||
}
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&types.Account{}).
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&patch)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update account network")
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -202,6 +202,9 @@ type Store interface {
|
||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
||||
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
|
||||
MarkAccountPrimary(ctx context.Context, accountID string) error
|
||||
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -16,16 +16,16 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -89,6 +89,12 @@ type Account struct {
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
// this class is used by gorm only
|
||||
type PrimaryAccountInfo struct {
|
||||
IsDomainPrimaryAccount bool
|
||||
Domain string
|
||||
}
|
||||
|
||||
// Subclass used in gorm to only load network and not whole account
|
||||
type AccountNetwork struct {
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
|
||||
Reference in New Issue
Block a user