mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 09:03:54 -04:00
[management] Remove save account calls (#4349)
This commit is contained in:
@@ -1952,20 +1952,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
|||||||
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
|
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
|
||||||
var account *types.Account
|
|
||||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
var err error
|
var err error
|
||||||
account, err = transaction.GetAccount(ctx, accountId)
|
ok, domain, err := transaction.IsPrimaryAccount(ctx, accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.IsDomainPrimaryAccount {
|
if ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
|
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
|
||||||
|
|
||||||
// error is not a not found error
|
// error is not a not found error
|
||||||
if handleNotFound(err) != nil {
|
if handleNotFound(err) != nil {
|
||||||
@@ -1981,9 +1980,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
|||||||
return status.Errorf(status.Internal, "cannot update account to primary")
|
return status.Errorf(status.Internal, "cannot update account to primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
account.IsDomainPrimaryAccount = true
|
if err := transaction.MarkAccountPrimary(ctx, accountId); err != nil {
|
||||||
|
|
||||||
if err := transaction.SaveAccount(ctx, account); err != nil {
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"accountId": accountId,
|
"accountId": accountId,
|
||||||
}).Errorf("failed to update account to primary: %v", err)
|
}).Errorf("failed to update account to primary: %v", err)
|
||||||
@@ -1993,10 +1990,10 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
|
||||||
@@ -2067,14 +2064,12 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
|
|||||||
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
|
Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := transaction.GetAccount(ctx, accountID)
|
err := transaction.UpdateAccountNetwork(ctx, accountID, newIPNet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.Net = newIPNet
|
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "")
|
||||||
|
|
||||||
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -2094,10 +2089,6 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
|
|||||||
takenIPs = append(takenIPs, newIP)
|
takenIPs = append(takenIPs, newIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||||
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
|
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -18,6 +17,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ExternalCacheManager nbcache.UserDataCache
|
type ExternalCacheManager nbcache.UserDataCache
|
||||||
@@ -120,7 +120,7 @@ type Manager interface {
|
|||||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||||
GetStore() store.Store
|
GetStore() store.Store
|
||||||
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3250,11 +3250,13 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
|||||||
assert.Equal(t, 0, len(account2.Users))
|
assert.Equal(t, 0, len(account2.Users))
|
||||||
assert.Equal(t, 0, len(account2.SetupKeys))
|
assert.Equal(t, 0, len(account2.SetupKeys))
|
||||||
|
|
||||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
account, err = manager.Store.GetAccount(ctx, account.Id)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, account.IsDomainPrimaryAccount)
|
assert.True(t, account.IsDomainPrimaryAccount)
|
||||||
|
|
||||||
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||||
assert.Error(t, err, "should not be able to update a second account to primary")
|
assert.Error(t, err, "should not be able to update a second account to primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3275,7 +3277,9 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
|||||||
assert.False(t, account.IsDomainPrimaryAccount)
|
assert.False(t, account.IsDomainPrimaryAccount)
|
||||||
assert.Equal(t, domain, account.Domain)
|
assert.Equal(t, domain, account.Domain)
|
||||||
|
|
||||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
account, err = manager.Store.GetAccount(ctx, account.Id)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.True(t, account.IsDomainPrimaryAccount)
|
assert.True(t, account.IsDomainPrimaryAccount)
|
||||||
|
|
||||||
|
|||||||
@@ -50,23 +50,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
a, err := transaction.GetAccount(ctx, accountID)
|
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var extra *types.ExtraSettings
|
var extra *types.ExtraSettings
|
||||||
|
|
||||||
if a.Settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
extra = a.Settings.Extra
|
extra = settings.Extra
|
||||||
} else {
|
} else {
|
||||||
extra = &types.ExtraSettings{}
|
extra = &types.ExtraSettings{}
|
||||||
a.Settings.Extra = extra
|
settings.Extra = extra
|
||||||
}
|
}
|
||||||
|
|
||||||
extra.IntegratedValidator = validator
|
extra.IntegratedValidator = validator
|
||||||
extra.IntegratedValidatorGroups = groups
|
extra.IntegratedValidatorGroups = groups
|
||||||
return transaction.SaveAccount(ctx, a)
|
return transaction.SaveAccountSettings(ctx, accountID, settings)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -21,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/users"
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ account.Manager = (*MockAccountManager)(nil)
|
var _ account.Manager = (*MockAccountManager)(nil)
|
||||||
@@ -114,7 +114,7 @@ type MockAccountManager struct {
|
|||||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||||
GetStoreFunc func() store.Store
|
GetStoreFunc func() store.Store
|
||||||
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
|
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error
|
||||||
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
|
||||||
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||||
@@ -933,11 +933,11 @@ func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Cont
|
|||||||
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
|
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
|
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error {
|
||||||
if am.UpdateToPrimaryAccountFunc != nil {
|
if am.UpdateToPrimaryAccountFunc != nil {
|
||||||
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
|
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
|
return status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {
|
func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) {
|
||||||
|
|||||||
@@ -2832,3 +2832,57 @@ func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFu
|
|||||||
}()
|
}()
|
||||||
return ctx, cancel
|
return ctx, cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||||
|
var info types.PrimaryAccountInfo
|
||||||
|
result := s.db.Model(&types.Account{}).
|
||||||
|
Select("is_domain_primary_account, domain").
|
||||||
|
Where(idQueryCondition, accountID).
|
||||||
|
Take(&info)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return info.IsDomainPrimaryAccount, info.Domain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
|
||||||
|
result := s.db.Model(&types.Account{}).
|
||||||
|
Where(idQueryCondition, accountID).
|
||||||
|
Update("is_domain_primary_account", true)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to mark account as primary")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewAccountNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type accountNetworkPatch struct {
|
||||||
|
Network *types.Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error {
|
||||||
|
patch := accountNetworkPatch{
|
||||||
|
Network: &types.Network{Net: ipNet},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).
|
||||||
|
Model(&types.Account{}).
|
||||||
|
Where(idQueryCondition, accountID).
|
||||||
|
Updates(&patch)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to update account network")
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewAccountNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -202,6 +202,9 @@ type Store interface {
|
|||||||
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
|
||||||
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
|
||||||
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)
|
||||||
|
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
|
||||||
|
MarkAccountPrimary(ctx context.Context, accountID string) error
|
||||||
|
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -16,16 +16,16 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -89,6 +89,12 @@ type Account struct {
|
|||||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// this class is used by gorm only
|
||||||
|
type PrimaryAccountInfo struct {
|
||||||
|
IsDomainPrimaryAccount bool
|
||||||
|
Domain string
|
||||||
|
}
|
||||||
|
|
||||||
// Subclass used in gorm to only load network and not whole account
|
// Subclass used in gorm to only load network and not whole account
|
||||||
type AccountNetwork struct {
|
type AccountNetwork struct {
|
||||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||||
|
|||||||
Reference in New Issue
Block a user