Compare commits

...

3 Commits

Author SHA1 Message Date
Maycon Santos
5f5d597c59 fix GetSetupKeyByID 2025-06-26 13:54:04 +02:00
Maycon Santos
b85aad07d4 add getTXWithLockStrength method 2025-06-26 13:44:00 +02:00
Maycon Santos
7398836c2e Stop using locking share for read calls to avoid deadlocks
Added peer.userID index
2025-06-26 12:24:42 +02:00
29 changed files with 295 additions and 493 deletions

View File

@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return err
}
@@ -708,7 +708,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID)
return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
}
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -720,7 +720,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided")
}
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -775,7 +775,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil {
return nil, nil, err
}
@@ -829,7 +829,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -859,7 +859,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err
@@ -1010,7 +1010,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID)
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err
@@ -1020,7 +1020,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err)
return err
@@ -1185,7 +1185,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
}
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
@@ -1205,7 +1205,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1237,7 +1237,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return err
}
@@ -1263,12 +1263,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool
var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -1298,7 +1298,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
@@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
@@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
@@ -1353,7 +1353,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else {
@@ -1414,7 +1414,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
}
if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists {
return "", err
}
@@ -1438,7 +1438,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@@ -1459,7 +1459,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
}
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1474,7 +1474,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil {
cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1485,7 +1485,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@@ -1495,7 +1495,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
}
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err
@@ -1506,7 +1506,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
}
// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err
@@ -1636,7 +1636,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil {
return false, err
}
@@ -1658,7 +1658,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err)
}
@@ -1684,7 +1684,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1770,7 +1770,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil {
return nil, false, err
}
@@ -1790,7 +1790,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 {
accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists {
continue
}
@@ -1865,7 +1865,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil
}
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
@@ -1902,7 +1902,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, false, err
}
@@ -1912,7 +1912,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsMap[group.ID] = group
}
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, false, err
}
@@ -1920,7 +1920,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsToUpdate := make(map[string]*types.Group)
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil {
return false, false, err
}

View File

@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return
}
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -899,11 +899,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
}
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err)
assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err)
assert.Len(t, pats, 0)
}
@@ -1775,7 +1775,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings)
@@ -1960,7 +1960,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2643,7 +2643,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
@@ -2657,7 +2657,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})
@@ -2671,11 +2671,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
})
@@ -2692,11 +2692,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
})
@@ -2710,7 +2710,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@@ -2724,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@@ -2738,11 +2738,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
@@ -2756,7 +2756,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2771,7 +2771,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
})
@@ -3354,7 +3354,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID)
@@ -3365,7 +3365,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err)
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
@@ -3376,7 +3376,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID)
@@ -3403,7 +3403,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
@@ -3420,7 +3420,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
})
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = []string{"group1"}
@@ -3431,7 +3431,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
return userAuth, err
}
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err
}
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil {
return err
}
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err
})
if err != nil {

View File

@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}

View File

@@ -122,7 +122,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare)
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return

View File

@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
}
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil {
return nil, err
}
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue
}
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil {
// @todo consider logging
continue

View File

@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}
// SaveGroup object of the peers
@@ -140,7 +140,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -152,13 +152,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
}
modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil
@@ -431,7 +431,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
}
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
@@ -448,7 +448,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
@@ -460,7 +460,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to get user")
}
@@ -506,7 +506,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings")
}
@@ -515,7 +515,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get account settings")
}
@@ -529,7 +529,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
@@ -549,7 +549,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
@@ -567,7 +567,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
@@ -586,7 +586,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
@@ -602,7 +602,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
@@ -618,7 +618,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil
@@ -638,7 +638,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
@@ -666,7 +666,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil {
return false, err
}

View File

@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err
}
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
}
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err)
}
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return nil, fmt.Errorf("error getting group: %w", err)
}
// TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID)
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err)
}
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err)
}
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return nil, fmt.Errorf("error getting group: %w", err)
}
// TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err)
}

View File

@@ -63,7 +63,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
if err != nil {
return err
}
@@ -83,17 +83,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
var peers []*nbpeer.Peer
var settings *types.Settings
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}

View File

@@ -641,7 +641,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return

View File

@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
}
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err
}
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
if err != nil {
return err
}

View File

@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID)
return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
}
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {

View File

@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
}
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
return nil, status.NewPermissionDeniedError()
}
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID)
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err)
}
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
return nil, status.NewPermissionDeniedError()
}
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID)
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err)
}
@@ -218,17 +218,17 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
}
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name)
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID {
return status.Errorf(status.InvalidArgument, "new resource name already exists")
}
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID)
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}

View File

@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError()
}
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID)
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
}
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError()
}
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get network routers: %w", err)
}
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
return nil, status.NewPermissionDeniedError()
}
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID)
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err)
}
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID)
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
}
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
}

View File

@@ -35,7 +35,7 @@ import (
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
@@ -45,7 +45,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, status.NewPermissionValidationError(err)
}
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter)
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
if err != nil {
return nil, err
}
@@ -55,7 +55,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return accountPeers, nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get account settings: %w", err)
}
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -216,7 +216,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err
}
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -335,7 +335,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPermissionDeniedError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil {
return err
}
@@ -468,7 +468,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false
if len(userID) > 0 {
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
}
@@ -488,7 +488,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
@@ -584,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels,
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
@@ -674,7 +674,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID)
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
@@ -706,7 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var err error
var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -718,7 +718,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil {
return err
}
@@ -821,7 +821,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
var isPeerUpdated bool
var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -906,7 +906,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -930,7 +930,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs)
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
if err != nil {
return nil, err
}
@@ -945,7 +945,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources)
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
@@ -970,7 +970,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -981,7 +981,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -1000,7 +1000,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -1062,7 +1062,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
@@ -1104,7 +1104,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
@@ -1117,7 +1117,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return peer, nil
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
@@ -1143,7 +1143,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well.
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
if err != nil {
return nil, err
}
@@ -1328,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return peerSchedulerRetryInterval, true
@@ -1338,7 +1338,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true
@@ -1372,7 +1372,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return peerSchedulerRetryInterval, true
@@ -1382,7 +1382,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true
@@ -1413,12 +1413,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1436,12 +1436,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1459,12 +1459,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
return am.Store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID)
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return nil, err
}
@@ -1478,7 +1478,7 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
}
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID)
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1505,7 +1505,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1516,7 +1516,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err
}
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1577,7 +1577,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID)
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil

View File

@@ -31,7 +31,7 @@ type Peer struct {
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string
UserID string `gorm:"index"`
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer

View File

@@ -1301,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
@@ -1423,7 +1423,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key)
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
@@ -1505,7 +1505,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
require.Error(t, err)
account, err := s.GetAccount(context.Background(), existingAccountID)
@@ -1671,7 +1671,7 @@ func Test_LoginPeer(t *testing.T) {
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey)
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")

View File

@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID)
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}

View File

@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
return true, nil
}
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return false, err
}

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
}
// SavePolicy in the store
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
}
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return false, err
}
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return err
}
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.AccountID = accountID
}
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil {
return err
}
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil {
return err
}

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
if err != nil {
return err
}
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
return err
}
return nil
}
// For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}

View File

@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
seenPeers[string(prefixRoute.ID)] = true
}
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
}
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
if err != nil {
return err
}
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
}
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
// validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
if err != nil {
return nil, err
}
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}

View File

@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *types.Group
for _, group := range groups {

View File

@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
return nil, fmt.Errorf("get extra settings: %w", err)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("get account settings: %w", err)
}
@@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
return nil, fmt.Errorf("get extra settings: %w", err)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("get account settings: %w", err)
}

View File

@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
}
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyToSave.Id)
if err != nil {
return err
}
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil {
return nil, err
}
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil {
return err
}
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
}
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
if err != nil {
return err
}
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil

View File

@@ -478,7 +478,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
if err != nil {
return nil, err
}
@@ -488,11 +488,7 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&types.Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
@@ -542,11 +538,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
}
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var user types.User
result := tx.
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
@@ -563,11 +555,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var user types.User
result := tx.First(&user, idQueryCondition, userID)
if result.Error != nil {
@@ -600,11 +588,7 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
@@ -619,11 +603,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var user types.User
result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
@@ -637,11 +617,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
@@ -656,11 +632,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
likePattern := `%"ID":"` + resourceID + `"%`
@@ -706,11 +678,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
}
func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountMeta types.AccountMeta
result := tx.Model(&types.Account{}).
First(&accountMeta, idQueryCondition, accountID)
@@ -880,11 +848,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&types.User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
@@ -899,11 +863,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
}
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
@@ -936,11 +896,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
@@ -968,11 +924,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var labels []string
result := tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
@@ -990,11 +942,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1006,11 +954,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peer nbpeer.Peer
result := tx.First(&peer, GetKeyQueryCondition(s), peerKey)
@@ -1025,11 +969,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountSettings types.AccountSettings
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1041,11 +981,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
}
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var createdBy string
result := tx.Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID)
@@ -1243,11 +1179,7 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var setupKey types.SetupKey
result := tx.
First(&setupKey, GetKeyQueryCondition(s), key)
@@ -1391,11 +1323,7 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
@@ -1409,8 +1337,9 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID)
query := tx.Where(accountIDCondition, accountID)
if nameFilter != "" {
query = query.Where("name LIKE ?", "%"+nameFilter+"%")
@@ -1429,11 +1358,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
@@ -1461,11 +1386,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStr
// GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peer *nbpeer.Peer
result := tx.
First(&peer, accountAndIDQueryCondition, accountID, peerID)
@@ -1481,11 +1402,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
// GetPeersByIDs retrieves peers by their IDs and account ID.
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
if result.Error != nil {
@@ -1503,11 +1420,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1522,11 +1435,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer
result := tx.
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1541,11 +1450,7 @@ func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStreng
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := tx.
Where("ephemeral = ?", true).
@@ -1623,11 +1528,7 @@ func (s *SqlStore) GetDB() *gorm.DB {
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountDNSSettings types.AccountDNSSettings
result := tx.Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
@@ -1643,11 +1544,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var accountID string
result := tx.Model(&types.Account{}).
Select("id").First(&accountID, idQueryCondition, id)
@@ -1663,11 +1560,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var account types.Account
result := tx.Model(&types.Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
@@ -1683,11 +1576,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
@@ -1703,11 +1592,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var group types.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently
@@ -1736,11 +1621,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
@@ -1796,11 +1677,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var policies []*types.Policy
result := tx.
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
@@ -1814,11 +1691,7 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var policy *types.Policy
result := tx.Preload(clause.Associations).
@@ -1880,11 +1753,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil {
@@ -1897,10 +1766,7 @@ func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength Loc
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var postureCheck *posture.Checks
result := tx.
@@ -1918,11 +1784,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil {
@@ -1967,9 +1829,9 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
result := tx.Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
@@ -1980,9 +1842,9 @@ func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStr
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
result := tx.First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID)
@@ -2023,11 +1885,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var setupKeys []*types.SetupKey
result := tx.
Find(&setupKeys, accountIDCondition, accountID)
@@ -2041,14 +1899,9 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var setupKey *types.SetupKey
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
result := tx.First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
@@ -2088,11 +1941,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var nsGroups []*nbdns.NameServerGroup
result := tx.Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
@@ -2105,11 +1954,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var nsGroup *nbdns.NameServerGroup
result := tx.
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
@@ -2182,11 +2027,7 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var networks []*networkTypes.Network
result := tx.Find(&networks, accountIDCondition, accountID)
if result.Error != nil {
@@ -2198,11 +2039,7 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
}
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var network *networkTypes.Network
result := tx.
First(&network, accountAndIDQueryCondition, accountID, networkID)
@@ -2244,11 +2081,7 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
}
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netRouters []*routerTypes.NetworkRouter
result := tx.
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
@@ -2261,11 +2094,7 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
}
func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netRouters []*routerTypes.NetworkRouter
result := tx.
Find(&netRouters, accountIDCondition, accountID)
@@ -2278,11 +2107,7 @@ func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrengt
}
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netRouter *routerTypes.NetworkRouter
result := tx.
First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
@@ -2323,11 +2148,7 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources []*resourceTypes.NetworkResource
result := tx.
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
@@ -2340,11 +2161,7 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
}
func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources []*resourceTypes.NetworkResource
result := tx.
Find(&netResources, accountIDCondition, accountID)
@@ -2357,11 +2174,7 @@ func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStren
}
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources *resourceTypes.NetworkResource
result := tx.
First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
@@ -2377,11 +2190,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var netResources *resourceTypes.NetworkResource
result := tx.
First(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
@@ -2423,10 +2232,7 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var pat types.PersonalAccessToken
result := tx.First(&pat, "hashed_token = ?", hashedToken)
@@ -2443,11 +2249,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var pat types.PersonalAccessToken
result := tx.
First(&pat, "id = ? AND user_id = ?", patID, userID)
@@ -2464,11 +2266,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
// GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
var pats []*types.PersonalAccessToken
result := tx.Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil {
@@ -2528,11 +2326,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
}
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
tx := s.getTXWithLockStrength(lockStrength)
jsonValue := fmt.Sprintf(`"%s"`, ip.String())
var peer nbpeer.Peer
@@ -2559,3 +2353,11 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
return count, nil
}
func (s *SqlStore) getTXWithLockStrength(lockStrength LockingStrength) *gorm.DB {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
return tx
}

View File

@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
inviterID := userID
if initiatorUser.IsServiceUser {
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID)
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
}
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
}
// GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
return nil, err
}
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
}
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
}
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.NewPermissionDeniedError()
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return err
}
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return err
}
@@ -404,7 +404,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewAdminPermissionError()
}
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
if err != nil {
return err
}
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID)
return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
}
// GetAllPATs returns all PATs for a user
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewAdminPermissionError()
}
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID)
return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
}
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if !allowed {
return nil, status.NewPermissionDeniedError()
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var addUserEvents []func()
var usersToSave = make([]*types.User, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
}
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, err
}
@@ -695,7 +695,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id)
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists {
@@ -830,7 +830,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var user *types.User
if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
@@ -840,7 +840,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{}
switch {
case allowed:
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -933,7 +933,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
@@ -1003,7 +1003,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.NewPermissionDeniedError()
}
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
}
@@ -1017,7 +1017,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
@@ -1081,12 +1081,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID)
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserInfo.ID)
if err != nil {
return fmt.Errorf("failed to get user to delete: %w", err)
}
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID)
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil {
return fmt.Errorf("failed to get user peers: %w", err)
}
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
// GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID)
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
@@ -1257,7 +1257,7 @@ func validateUserInvite(invite *types.UserInfo) error {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return nil, err
}
@@ -1274,7 +1274,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}

View File

@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId)
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)

View File

@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
}
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}
func NewManagerMock() Manager {