mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
32 Commits
fix/proxy_
...
deploy/pee
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c44187900 | ||
|
|
0dc050b354 | ||
|
|
69eb6f17c6 | ||
|
|
ccbe2e0100 | ||
|
|
5b042eee74 | ||
|
|
125c418d85 | ||
|
|
9f607029c0 | ||
|
|
4df1d556c8 | ||
|
|
acaf315152 | ||
|
|
81b6f7dc56 | ||
|
|
cb9a8e9fa4 | ||
|
|
f68d9ce042 | ||
|
|
1f182411c6 | ||
|
|
68e971cb82 | ||
|
|
f4a464fb69 | ||
|
|
7aff0e828b | ||
|
|
90ebf0017b | ||
|
|
bf8c6b95de | ||
|
|
8415064758 | ||
|
|
809f6cdbc7 | ||
|
|
7474876d6a | ||
|
|
2095e35217 | ||
|
|
02537e5536 | ||
|
|
f99477d3db | ||
|
|
3085383729 | ||
|
|
07bdf977d0 | ||
|
|
4a147006d2 | ||
|
|
480192028e | ||
|
|
027c0e5c63 | ||
|
|
fbb1181b64 | ||
|
|
e50ed290d9 | ||
|
|
8f8cfcbc20 |
@@ -11,6 +11,7 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -76,7 +77,7 @@ type AccountManager interface {
|
||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(accountID string) ([]*User, error)
|
||||
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error
|
||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
|
||||
DeletePeer(accountID, peerID, userID string) error
|
||||
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
@@ -117,8 +118,8 @@ type AccountManager interface {
|
||||
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
HasConnectedChannel(peerID string) bool
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
@@ -130,6 +131,8 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
|
||||
GroupValidation(accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||
CancelPeerRoutines(peer *nbpeer.Peer) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -386,6 +389,8 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
|
||||
log.Debugf("GetNetworkMap with trace: %s", string(debug.Stack()))
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
@@ -958,7 +963,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -1009,7 +1014,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -1108,7 +1113,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
|
||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||
func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
@@ -1567,7 +1572,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(user.Id)
|
||||
@@ -1650,7 +1655,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
unlock := am.Store.AcquireAccountLock(newAcc.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
@@ -1801,7 +1806,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
|
||||
account, err := am.Store.GetAccountByUser(claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireAccountLock(account.Id)
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(claims.UserId)
|
||||
if err != nil {
|
||||
@@ -1821,7 +1826,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
return account, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireAccountLock(domainAccount.Id)
|
||||
unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||
if err != nil {
|
||||
@@ -1835,6 +1840,56 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountReadLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
|
||||
if err != nil {
|
||||
return nil, nil, mapError(err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peer.Key, false, nil, account)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||
|
||||
@@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
@@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1745,7 +1750,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil)
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountLock acquires account lock and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Debugf("acquiring lock for account %s", accountID)
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
@@ -295,6 +295,12 @@ func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||
// This method is still returns a write lock as file store can't handle read locks
|
||||
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(accountID)
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -572,6 +578,18 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return account.Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
||||
if !ok {
|
||||
return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// GetInstallationID returns the installation ID from the store
|
||||
func (s *FileStore) GetInstallationID() string {
|
||||
return s.InstallationID
|
||||
|
||||
@@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -214,7 +214,7 @@ func difference(a, b []string) []string {
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountId)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
return err
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||
@@ -149,11 +149,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(peer)
|
||||
|
||||
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||
}
|
||||
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||
}
|
||||
@@ -207,7 +202,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
|
||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(userID)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -99,3 +101,104 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
|
||||
// T is the type of the model that contains the field to be migrated.
|
||||
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
|
||||
oldColumnName := fieldName
|
||||
newColumnName := fieldName + "_tmp"
|
||||
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.Printf("Table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
var item sql.NullString
|
||||
if err := db.Model(&model).Select(oldColumnName).First(&item).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Printf("No records in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("fetch first record: %w", err)
|
||||
}
|
||||
|
||||
if item.Valid {
|
||||
var js json.RawMessage
|
||||
var syntaxError *json.SyntaxError
|
||||
err = json.Unmarshal([]byte(item.String), &js)
|
||||
if err == nil || !errors.As(err, &syntaxError) {
|
||||
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s TEXT", tableName, newColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("add column %s: %w", newColumnName, err)
|
||||
}
|
||||
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", oldColumnName).Find(&rows).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Printf("No records in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
var blobValue string
|
||||
if columnValue := row[oldColumnName]; columnValue != nil {
|
||||
value, ok := columnValue.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("type assertion failed")
|
||||
}
|
||||
blobValue = value
|
||||
}
|
||||
|
||||
columnIpValue := net.IP(blobValue)
|
||||
if net.ParseIP(columnIpValue.String()) == nil {
|
||||
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
|
||||
columnIpValue = net.IPv6loopback
|
||||
}
|
||||
|
||||
jsonValue, err := json.Marshal(columnIpValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("re-encode to JSON: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, jsonValue).Error; err != nil {
|
||||
return fmt.Errorf("update row: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if indexName != "" {
|
||||
if err := tx.Migrator().DropIndex(&model, indexName); err != nil {
|
||||
return fmt.Errorf("drop index %s: %w", indexName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableName, oldColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("drop column %s: %w", oldColumnName, err)
|
||||
}
|
||||
if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", tableName, newColumnName, oldColumnName)).Error; err != nil {
|
||||
return fmt.Errorf("rename column %s to %s: %w", newColumnName, oldColumnName, err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Migration of %s.%s from blob to json completed", tableName, fieldName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -89,3 +90,72 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
|
||||
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
|
||||
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
type location struct {
|
||||
nbpeer.Location
|
||||
ConnectionIP net.IP
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
nbpeer.Peer
|
||||
Location location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
}
|
||||
|
||||
type account struct {
|
||||
server.Account
|
||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||
}
|
||||
|
||||
err = db.Save(&account{
|
||||
Account: server.Account{Id: "123"},
|
||||
Peers: []peer{
|
||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
).Error
|
||||
require.NoError(t, err, "Failed to insert blob data")
|
||||
|
||||
var blobValue string
|
||||
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
|
||||
assert.NoError(t, err, "Failed to fetch blob data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
require.NoError(t, err, "Migration should not fail with net.IP blob data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be migrated")
|
||||
}
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
|
||||
err = db.Save(&server.Account{
|
||||
Id: "1234",
|
||||
PeersG: []nbpeer.Peer{
|
||||
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
}},
|
||||
).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
|
||||
|
||||
var jsonStr string
|
||||
db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr)
|
||||
assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged")
|
||||
}
|
||||
|
||||
@@ -22,80 +22,93 @@ type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(pat string) error
|
||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(pat string) error
|
||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
|
||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(peerPubKey, realIP)
|
||||
}
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
@@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
|
||||
}
|
||||
@@ -626,7 +639,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(sync)
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
@@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -88,21 +88,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -156,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -278,7 +264,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -362,7 +348,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
@@ -381,7 +367,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
|
||||
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow)
|
||||
// Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow)
|
||||
// and the peer disconnects with a timeout and tries to register again.
|
||||
// We just check if this machine has been registered before and reject the second registration.
|
||||
// The connecting peer should be able to recover with a retry.
|
||||
@@ -518,25 +504,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
@@ -603,7 +571,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
@@ -760,7 +728,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
@@ -795,7 +763,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -13,13 +13,13 @@ type Peer struct {
|
||||
// ID is an internal ID of the peer
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
@@ -61,7 +61,7 @@ type PeerStatus struct { //nolint:revive
|
||||
|
||||
// Location is a geo location information of a Peer based on public connection IP
|
||||
type Location struct {
|
||||
ConnectionIP net.IP // from grpc peer or reverse proxy headers depends on setup
|
||||
ConnectionIP net.IP `gorm:"serializer:json"` // from grpc peer or reverse proxy headers depends on setup
|
||||
CountryCode string
|
||||
CityName string
|
||||
GeoNameID uint // city level geoname id
|
||||
|
||||
@@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID,
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -115,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -194,7 +194,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string,
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if routeToSave == nil {
|
||||
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string)
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -209,7 +209,7 @@ func Hash(s string) uint32 {
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
keyDuration := DefaultSetupKeyDuration
|
||||
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
@@ -327,7 +327,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
||||
|
||||
// ListSetupKeys returns a list of all setup keys of the account
|
||||
func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
@@ -359,7 +359,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
|
||||
|
||||
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
|
||||
func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
|
||||
@@ -127,17 +127,33 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring lock for account %s", accountID)
|
||||
func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring write lock for account %s", accountID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
mtx := value.(*sync.Mutex)
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.Tracef("released lock for account %s in %v", accountID, time.Since(start))
|
||||
log.Tracef("released write lock for account %s in %v", accountID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
log.Tracef("acquiring read lock for account %s", accountID)
|
||||
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.Tracef("released read lock for account %s in %v", accountID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -263,36 +279,43 @@ func (s *SqliteStore) GetInstallationID() string {
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peer nbpeer.Peer
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||
Updates(peerCopy)
|
||||
|
||||
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||
}
|
||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "issue getting peer from store")
|
||||
return result.Error
|
||||
}
|
||||
|
||||
peer.Status = &peerStatus
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||
}
|
||||
|
||||
return s.db.Save(peer).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
var peer nbpeer.Peer
|
||||
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID)
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
// updating the struct ensures the correct data format is inserted into the database.
|
||||
peerCopy.Location = peerWithLocation.Location
|
||||
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
|
||||
}
|
||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "issue getting peer from store")
|
||||
return result.Error
|
||||
}
|
||||
|
||||
peer.Location = peerWithLocation.Location
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
|
||||
}
|
||||
|
||||
return s.db.Save(peer).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
|
||||
@@ -400,8 +423,9 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) {
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
|
||||
|
||||
var account Account
|
||||
result := s.db.Model(&account).
|
||||
result := s.db.Debug().Model(&account).
|
||||
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
|
||||
Preload(clause.Associations).
|
||||
First(&account, "id = ?", accountID)
|
||||
@@ -521,6 +545,21 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return s.GetAccount(peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
var peer nbpeer.Peer
|
||||
var accountID string
|
||||
result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
@@ -571,13 +610,17 @@ func getMigrations() []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net")
|
||||
},
|
||||
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network")
|
||||
},
|
||||
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -12,6 +10,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -518,15 +519,29 @@ func TestMigrate(t *testing.T) {
|
||||
Net net.IPNet `gorm:"serializer:gob"`
|
||||
}
|
||||
|
||||
type location struct {
|
||||
nbpeer.Location
|
||||
ConnectionIP net.IP
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
nbpeer.Peer
|
||||
Location location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
}
|
||||
|
||||
type account struct {
|
||||
Account
|
||||
Network *network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
|
||||
}
|
||||
|
||||
act := &account{
|
||||
Network: &network{
|
||||
Net: *ipnet,
|
||||
},
|
||||
Peers: []peer{
|
||||
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
|
||||
},
|
||||
}
|
||||
|
||||
err = store.db.Save(act).Error
|
||||
|
||||
@@ -19,6 +19,7 @@ type Store interface {
|
||||
DeleteAccount(account *Account) error
|
||||
GetAccountByUser(userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||
GetAccountIDByPeerPubKey(peerKey string) (string, error)
|
||||
GetAccountByPeerID(peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
@@ -29,8 +30,10 @@ type Store interface {
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ID string) error
|
||||
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
|
||||
AcquireAccountLock(accountID string) func()
|
||||
// AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock
|
||||
AcquireAccountWriteLock(accountID string) func()
|
||||
// AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock
|
||||
AcquireAccountReadLock(accountID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock() func()
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
|
||||
@@ -210,7 +210,7 @@ func NewOwnerUser(id string) *User {
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -266,7 +266,7 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *User
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
|
||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
@@ -400,7 +400,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (
|
||||
// ListUsers returns lists of all users under the account.
|
||||
// It doesn't populate user information such as email or name.
|
||||
func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
|
||||
if initiatorUserID == targetUserID {
|
||||
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
||||
}
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -537,7 +537,7 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
@@ -577,7 +577,7 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if tokenName == "" {
|
||||
@@ -627,7 +627,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -677,7 +677,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str
|
||||
|
||||
// GetPAT returns a specific PAT from a user
|
||||
func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -709,7 +709,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
@@ -747,7 +747,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
|
||||
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if update == nil {
|
||||
|
||||
Reference in New Issue
Block a user