mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 14:44:34 -04:00
Compare commits
78 Commits
debug-api
...
refactor/g
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a23a09bba3 | ||
|
|
2f7027194b | ||
|
|
197d844a16 | ||
|
|
df6c9a528a | ||
|
|
9cb7336ef5 | ||
|
|
e513e51e9f | ||
|
|
4ad00e784c | ||
|
|
bfeb7f0875 | ||
|
|
dde01b8e02 | ||
|
|
74246d18ba | ||
|
|
fa5db7d7ee | ||
|
|
fed48de83f | ||
|
|
e73b5da42b | ||
|
|
8cacdae70c | ||
|
|
6b94f6e4e7 | ||
|
|
b7525d9fe8 | ||
|
|
901d283114 | ||
|
|
7278a21b0d | ||
|
|
9bf0bf4843 | ||
|
|
313e158e20 | ||
|
|
0bdcb41e20 | ||
|
|
97dbdd7940 | ||
|
|
a82b5ce80e | ||
|
|
83be99c849 | ||
|
|
ee96a81b83 | ||
|
|
b0edc5f1f7 | ||
|
|
408d0cd504 | ||
|
|
b66f331711 | ||
|
|
d7a6996bed | ||
|
|
d7c63d5c04 | ||
|
|
1123729c1c | ||
|
|
a8c8b77df8 | ||
|
|
0297b5f142 | ||
|
|
78e238646c | ||
|
|
f9ed25f8b1 | ||
|
|
f43a006c34 | ||
|
|
1a37b12d1b | ||
|
|
d36d30dec4 | ||
|
|
43eb7261e3 | ||
|
|
9e47c94a7f | ||
|
|
edf67672ad | ||
|
|
bc520412ba | ||
|
|
d87fe0257b | ||
|
|
b1b2b0adf0 | ||
|
|
96f18c2c8c | ||
|
|
73be8c8a32 | ||
|
|
f61c914fd7 | ||
|
|
4575ae2841 | ||
|
|
ca6a9fd602 | ||
|
|
871595d15f | ||
|
|
30253b0565 | ||
|
|
dc82c2d1ce | ||
|
|
3b4bcdf5a4 | ||
|
|
87c8430e99 | ||
|
|
c384874d7d | ||
|
|
b815393180 | ||
|
|
41b212f610 | ||
|
|
16174f0478 | ||
|
|
d14b855670 | ||
|
|
eab85644cd | ||
|
|
7561706627 | ||
|
|
1ffe89d20d | ||
|
|
28840383e1 | ||
|
|
d9f612d623 | ||
|
|
7601a17150 | ||
|
|
8f98adddf6 | ||
|
|
26dd045da5 | ||
|
|
4d9bb7ea35 | ||
|
|
9631cb4fb3 | ||
|
|
8f9c54f6c2 | ||
|
|
f60a4234b1 | ||
|
|
021fc8f33e | ||
|
|
a4c4158bcf | ||
|
|
720d36a290 | ||
|
|
ccab3b427f | ||
|
|
e5d55d3c10 | ||
|
|
3cf1b02f31 | ||
|
|
258b30cf48 |
@@ -38,7 +38,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -171,7 +170,7 @@ type Engine struct {
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
srWatcher *guard.SRWatcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
|
||||
3
client/testdata/store.sql
vendored
3
client/testdata/store.sql
vendored
@@ -28,9 +28,12 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
|
||||
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
|
||||
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
COMMIT;
|
||||
|
||||
@@ -136,6 +136,7 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
|
||||
func (g *NameServerGroup) Copy() *NameServerGroup {
|
||||
nsGroup := &NameServerGroup{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
NameServers: make([]NameServer, len(g.NameServers)),
|
||||
@@ -156,6 +157,7 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
|
||||
// IsEqual compares one nameserver group with the other
|
||||
func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
|
||||
return other.ID == g.ID &&
|
||||
other.AccountID == g.AccountID &&
|
||||
other.Name == g.Name &&
|
||||
other.Description == g.Description &&
|
||||
other.Primary == g.Primary &&
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -397,7 +398,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
store := newStore(t)
|
||||
|
||||
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "account-1")
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -415,6 +423,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
|
||||
store.Close(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -422,7 +432,15 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
||||
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -433,16 +451,16 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userID)
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
account, err := manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
|
||||
return
|
||||
@@ -665,15 +683,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountIDByUserID where the id is getting generated
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
@@ -689,44 +704,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
require.Len(t, accountGroups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
|
||||
require.NoError(t, err, "failed to update account settings")
|
||||
|
||||
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
|
||||
require.NoError(t, err, "failed to get account ids")
|
||||
require.Len(t, accountIDs, 1, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
|
||||
require.NoError(t, err, "failed to update account settings")
|
||||
|
||||
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
|
||||
require.NoError(t, err, "failed to get account ids")
|
||||
require.Len(t, accountIDs, 1, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to check account existence")
|
||||
require.True(t, exists, "account should exist")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
require.Len(t, accountGroups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*group.Group{}
|
||||
for _, g := range account.Groups {
|
||||
for _, g := range accountGroups {
|
||||
groupsByNames[g.Name] = g
|
||||
}
|
||||
|
||||
@@ -744,60 +768,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
|
||||
userPAT := &PersonalAccessToken{
|
||||
ID: "tokenId",
|
||||
UserID: "testuser",
|
||||
HashedToken: encodedHashedToken,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
|
||||
require.NoError(t, err, "failed to save PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
||||
user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", account.Id)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "testuser", user.Id)
|
||||
assert.Equal(t, userPAT, pat)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
|
||||
userPAT := &PersonalAccessToken{
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
|
||||
require.NoError(t, err, "failed to save PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -808,11 +825,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
|
||||
userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
|
||||
require.NoError(t, err, "failed to get PAT")
|
||||
|
||||
assert.True(t, !userPAT.LastUsed.IsZero())
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
@@ -823,15 +839,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
account, err := manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
|
||||
}
|
||||
@@ -850,32 +866,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
|
||||
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
|
||||
}
|
||||
accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account domain and category")
|
||||
require.Equal(t, domain, accDomain, "expected account domain to match")
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
|
||||
if account == nil {
|
||||
t.Fatalf("expected to get an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
|
||||
}
|
||||
accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account domain and category")
|
||||
require.Equal(t, domain, accDomain, "expected account domain to match")
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
@@ -907,12 +913,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return am.Store.GetAccount(context.Background(), accountID)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
@@ -1055,23 +1060,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account network")
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
return
|
||||
}
|
||||
serial := network.CurrentSerial() // should be 0
|
||||
require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to generate private key")
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
@@ -1079,16 +1079,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to add peer")
|
||||
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
if peer.Key != expectedPeerKey {
|
||||
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
|
||||
@@ -1131,10 +1125,12 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
ID: xid.New().String(),
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -1212,10 +1208,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policyID := xid.New().String()
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
ID: policyID,
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
PolicyID: policyID,
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -1249,19 +1250,25 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
policyID := xid.New().String()
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
ID: policyID,
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
PolicyID: policyID,
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -1302,19 +1309,24 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &group)
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
policyID := xid.New().String()
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
ID: policyID,
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
PolicyID: policyID,
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -1324,6 +1336,9 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
@@ -1352,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, userID, group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1748,18 +1763,9 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
@@ -1774,11 +1780,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1802,10 +1808,13 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
|
||||
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
|
||||
require.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.PeerLoginExpirationEnabled = true
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -1822,11 +1831,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1854,10 +1860,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -1871,10 +1874,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
|
||||
require.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.PeerLoginExpirationEnabled = true
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
settings, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1884,10 +1889,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
wg.Add(1)
|
||||
|
||||
// disabling PeerLoginExpirationEnabled should trigger cancel
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
settings.PeerLoginExpirationEnabled = false
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
failed = waitTimeout(wg, time.Second)
|
||||
if failed {
|
||||
@@ -1902,30 +1905,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
settings.PeerLoginExpirationEnabled = false
|
||||
settings.PeerLoginExpiration = time.Hour
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
settings.PeerLoginExpirationEnabled = false
|
||||
settings.PeerLoginExpiration = time.Second
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
settings.PeerLoginExpirationEnabled = false
|
||||
settings.PeerLoginExpiration = time.Hour * 24 * 181
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||
}
|
||||
|
||||
@@ -2606,7 +2608,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2626,7 +2628,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2665,7 +2667,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
@@ -94,64 +99,105 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
||||
}
|
||||
|
||||
if dnsSettingsToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
oldSettings := account.DNSSettings.Copy()
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
updateAccountPeers, err := am.areDNSSettingChangesAffectPeers(ctx, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if dns settings changes affect peers: %w", err)
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave); err != nil {
|
||||
return fmt.Errorf("failed to update dns settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, g := range groups {
|
||||
groupMap[g.ID] = g
|
||||
}
|
||||
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
group, ok := groupMap[id]
|
||||
if ok {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
group, ok := groupMap[id]
|
||||
if ok {
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func (am *DefaultAccountManager) areDNSSettingChangesAffectPeers(ctx context.Context, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := am.anyGroupHasPeers(ctx, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return am.anyGroupHasPeers(ctx, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
|
||||
@@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
accountID, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Fatal("failed to init testing account")
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Fatal("DNS settings for new accounts shouldn't return nil")
|
||||
}
|
||||
|
||||
account.DNSSettings = DNSSettings{
|
||||
err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{
|
||||
DisabledManagementGroups: []string{group1ID},
|
||||
}
|
||||
})
|
||||
require.NoError(t, err, "failed to update DNS settings")
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save testing account with new DNS settings")
|
||||
}
|
||||
|
||||
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
|
||||
}
|
||||
|
||||
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
|
||||
_, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID)
|
||||
if err == nil {
|
||||
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
|
||||
}
|
||||
@@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
accountID, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
|
||||
err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
@@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
updatedAccount, err := am.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
t.Errorf("should be able to retrieve updated account, got err: %s", err)
|
||||
}
|
||||
@@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
accountID, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
peer1, err := account.FindPeerByPubKey(dnsPeer1Key)
|
||||
peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
peer2, err := account.FindPeerByPubKey(dnsPeer2Key)
|
||||
peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account DNS settings")
|
||||
|
||||
dnsSettings := accountDNSSettings.Copy()
|
||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||
account.DNSSettings = dnsSettings
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings)
|
||||
require.NoError(t, err, "failed to update DNS settings")
|
||||
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
require.NoError(t, err)
|
||||
@@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
|
||||
t.Helper()
|
||||
peer1 := &nbpeer.Peer{
|
||||
Key: dnsPeer1Key,
|
||||
@@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
|
||||
account.Users[dnsRegularUserID] = &User{
|
||||
Id: dnsRegularUserID,
|
||||
Role: UserRoleUser,
|
||||
err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: dnsRegularUserID,
|
||||
AccountID: dnsAccountID,
|
||||
Role: UserRoleUser,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer1, err = account.FindPeerByPubKey(peer1.Key)
|
||||
_, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = account.FindPeerByPubKey(peer2.Key)
|
||||
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{
|
||||
{
|
||||
ID: dnsGroup1ID,
|
||||
AccountID: dnsAccountID,
|
||||
Peers: []string{peer1.ID},
|
||||
Name: dnsGroup1ID,
|
||||
},
|
||||
{
|
||||
ID: dnsGroup2ID,
|
||||
AccountID: dnsAccountID,
|
||||
Name: dnsGroup2ID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
newGroup1 := &group.Group{
|
||||
ID: dnsGroup1ID,
|
||||
Peers: []string{peer1.ID},
|
||||
Name: dnsGroup1ID,
|
||||
}
|
||||
|
||||
newGroup2 := &group.Group{
|
||||
ID: dnsGroup2ID,
|
||||
Name: dnsGroup2ID,
|
||||
}
|
||||
|
||||
account.Groups[newGroup1.ID] = newGroup1
|
||||
account.Groups[newGroup2.ID] = newGroup2
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
|
||||
ID: dnsNSGroup1,
|
||||
Name: "ns-group-1",
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{
|
||||
ID: dnsNSGroup1,
|
||||
AccountID: dnsAccountID,
|
||||
Name: "ns-group-1",
|
||||
NameServers: []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(savedPeer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
@@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
Primary: true,
|
||||
Enabled: true,
|
||||
Groups: []string{allGroup.ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
return dnsAccountID, nil
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
|
||||
@@ -20,10 +20,10 @@ var (
|
||||
)
|
||||
|
||||
type ephemeralPeer struct {
|
||||
id string
|
||||
account *Account
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
id string
|
||||
accountID string
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
}
|
||||
|
||||
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||
@@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
@@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
return
|
||||
}
|
||||
|
||||
e.addPeer(peer.ID, a, newDeadLine())
|
||||
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
@@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
accounts := e.store.GetAllAccounts(context.Background())
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
t := newDeadLine()
|
||||
count := 0
|
||||
for _, a := range accounts {
|
||||
for id, p := range a.Peers {
|
||||
if p.Ephemeral {
|
||||
count++
|
||||
e.addPeer(id, a, t)
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.Ephemeral {
|
||||
count++
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
|
||||
}
|
||||
|
||||
@@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
|
||||
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
ep := &ephemeralPeer{
|
||||
id: id,
|
||||
account: account,
|
||||
deadline: deadline,
|
||||
id: peerID,
|
||||
accountID: accountID,
|
||||
deadline: deadline,
|
||||
}
|
||||
|
||||
if e.headPeer == nil {
|
||||
|
||||
@@ -7,25 +7,12 @@ import (
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type MockStore struct {
|
||||
Store
|
||||
account *Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
|
||||
return []*Account{s.account}
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
|
||||
_, ok := s.account.Peers[peerId]
|
||||
if ok {
|
||||
return s.account, nil
|
||||
}
|
||||
|
||||
return nil, status.NewPeerNotFoundError(peerId)
|
||||
accountID string
|
||||
}
|
||||
|
||||
type MocAccountManager struct {
|
||||
@@ -33,9 +20,8 @@ type MocAccountManager struct {
|
||||
store *MockStore
|
||||
}
|
||||
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
delete(a.store.account.Peers, peerID)
|
||||
return nil //nolint:nil
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error {
|
||||
return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
@@ -44,23 +30,26 @@ func TestNewManager(t *testing.T) {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
store := &MockStore{
|
||||
Store: newStore(t),
|
||||
}
|
||||
am := MocAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
require.NoError(t, err, "failed to seed peers")
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
if len(store.account.Peers) != numberOfPeers {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
|
||||
}
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers")
|
||||
}
|
||||
|
||||
func TestNewManagerPeerConnected(t *testing.T) {
|
||||
@@ -69,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
store := &MockStore{
|
||||
Store: newStore(t),
|
||||
}
|
||||
am := MocAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
require.NoError(t, err, "failed to seed peers")
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
|
||||
require.NoError(t, err, "failed to get peer")
|
||||
|
||||
mgr.OnPeerConnected(context.Background(), peer)
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers")
|
||||
}
|
||||
|
||||
func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
@@ -97,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
store := &MockStore{
|
||||
Store: newStore(t),
|
||||
}
|
||||
am := MocAccountManager{
|
||||
store: store,
|
||||
}
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
require.NoError(t, err, "failed to seed peers")
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
for _, v := range peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
}
|
||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
|
||||
require.NoError(t, err, "failed to get peer")
|
||||
mgr.OnPeerDisconnected(context.Background(), peer)
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
|
||||
require.NoError(t, err, "failed to get account peers")
|
||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers")
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error {
|
||||
accountID := "my account"
|
||||
err := newAccountWithId(context.Background(), store, accountID, "", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store.accountID = accountID
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
AccountID: accountID,
|
||||
Ephemeral: false,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
err = store.AddPeerToAccount(context.Background(), p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < numberOfEphemeralPeers; i++ {
|
||||
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
AccountID: accountID,
|
||||
Ephemeral: true,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
err = store.AddPeerToAccount(context.Background(), p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -50,7 +54,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@@ -59,31 +63,34 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccountGroups(ctx, accountID)
|
||||
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
|
||||
}
|
||||
|
||||
// SaveGroups adds new groups to the account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var (
|
||||
eventsToStore []func()
|
||||
groupsToSave []*nbgroup.Group
|
||||
)
|
||||
|
||||
for _, newGroup := range newGroups {
|
||||
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
|
||||
@@ -91,7 +98,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := account.FindGroupByName(newGroup.Name)
|
||||
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
@@ -109,15 +116,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if account.Peers[peerID] == nil {
|
||||
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
oldGroup := account.Groups[newGroup.ID]
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
newGroup.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
|
||||
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
}
|
||||
|
||||
@@ -126,30 +133,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
newGroupIDs = append(newGroupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, newGroupIDs) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
|
||||
return fmt.Errorf("failed to save groups: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
|
||||
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
if oldGroup != nil {
|
||||
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
@@ -159,12 +181,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
for _, peerID := range addedPeers {
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
||||
continue
|
||||
}
|
||||
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
|
||||
@@ -175,12 +198,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
})
|
||||
}
|
||||
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
for _, peerID := range removedPeers {
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
||||
continue
|
||||
}
|
||||
|
||||
peerCopy := peer // copy to avoid closure issues
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
@@ -210,119 +234,108 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
if group.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if err = am.validateDeleteGroup(ctx, group, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil {
|
||||
return fmt.Errorf("failed to delete group: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var (
|
||||
allErrors error
|
||||
groupIDsToDelete []string
|
||||
deletedGroups []*nbgroup.Group
|
||||
)
|
||||
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
|
||||
return fmt.Errorf("failed to delete group: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
for _, group := range deletedGroups {
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
for _, item := range account.Groups {
|
||||
groups = append(groups, item)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
add := true
|
||||
for _, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
@@ -334,13 +347,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
|
||||
return fmt.Errorf("failed to save group: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -348,41 +375,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
updated := false
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
updated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if areGroupChangesAffectPeers(account, []string{group.ID}) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if !updated {
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
|
||||
return fmt.Errorf("failed to save group: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
@@ -390,32 +431,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
|
||||
}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings.Extra != nil {
|
||||
if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
}
|
||||
@@ -424,17 +475,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) {
|
||||
routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) {
|
||||
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
@@ -446,7 +510,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
@@ -454,11 +524,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) {
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
@@ -468,7 +545,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) {
|
||||
users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
@@ -478,30 +561,46 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
}
|
||||
|
||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
|
||||
return true
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
|
||||
return true
|
||||
if group.HasPeers() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
}
|
||||
|
||||
routeResource := &route.Route{
|
||||
ID: "example route",
|
||||
Groups: []string{groupForRoute.ID},
|
||||
ID: "example route",
|
||||
AccountID: accountID,
|
||||
Groups: []string{groupForRoute.ID},
|
||||
}
|
||||
|
||||
routePeerGroupResource := &route.Route{
|
||||
ID: "example route with peer groups",
|
||||
AccountID: accountID,
|
||||
PeerGroups: []string{groupForRoute2.ID},
|
||||
}
|
||||
|
||||
nameServerGroup := &nbdns.NameServerGroup{
|
||||
ID: "example name server group",
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
ID: "example name server group",
|
||||
AccountID: accountID,
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
ID: "example policy",
|
||||
ID: "example policy",
|
||||
AccountID: accountID,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "example policy rule",
|
||||
PolicyID: "example policy",
|
||||
Destinations: []string{groupForPolicies.ID},
|
||||
},
|
||||
},
|
||||
@@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
|
||||
|
||||
setupKey := &SetupKey{
|
||||
Id: "example setup key",
|
||||
AccountID: accountID,
|
||||
AutoGroups: []string{groupForSetupKeys.ID},
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Id: "example user",
|
||||
AccountID: accountID,
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
account.Policies = append(account.Policies, policy)
|
||||
account.SetupKeys[setupKey.Id] = setupKey
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{
|
||||
groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies,
|
||||
groupForSetupKeys, groupForUsers, groupForUsers, groupForIntegration,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -394,24 +424,28 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
ID: "groupC",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{},
|
||||
ID: "groupD",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupD",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@@ -430,9 +464,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -501,10 +536,13 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
// adding a group to policy
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
ID: "policy",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
PolicyID: "policy",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -524,9 +562,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -593,9 +632,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
ID: "groupC",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -610,6 +650,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving group linked to route", func(t *testing.T) {
|
||||
newRoute := route.Route{
|
||||
ID: "route",
|
||||
AccountID: account.Id,
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superNet",
|
||||
NetworkType: route.IPv4Network,
|
||||
@@ -634,9 +675,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -661,9 +703,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID},
|
||||
ID: "groupD",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
||||
@@ -100,13 +100,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
updatedAccountSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
|
||||
resp := toAccountResponse(accountID, updatedAccountSettings)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||
return account.Settings, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@@ -39,9 +39,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
accCopy := account.Copy()
|
||||
accCopy.UpdateSettings(newSettings)
|
||||
return accCopy, nil
|
||||
return newSettings.Copy(), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
||||
@@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -68,7 +68,7 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
|
||||
|
||||
return nil, fmt.Errorf("unknown group name")
|
||||
},
|
||||
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
GetUserPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return maps.Values(TestPeers), nil
|
||||
},
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
|
||||
@@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
accountManager.GetAccountInfoFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
@@ -47,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
@@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
|
||||
@@ -33,7 +33,8 @@ var testAccount = &server.Account{
|
||||
Domain: domain,
|
||||
Users: map[string]*server.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
Id: userID,
|
||||
AccountID: accountID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
@@ -49,11 +50,11 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
@@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountFromPAT,
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
|
||||
@@ -120,6 +120,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
updatedNSGroup := &nbdns.NameServerGroup{
|
||||
ID: nsGroupID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Primary: req.Primary,
|
||||
|
||||
@@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
return peerToReturn, nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
@@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
groupsInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
@@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
@@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
return
|
||||
case http.MethodGet, http.MethodPut:
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodGet {
|
||||
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||
} else {
|
||||
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||
}
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), accountID, peerID, userID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
@@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||
peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(peerGroups)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
}
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
@@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsChecked := make(map[string]struct{})
|
||||
func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
|
||||
groupsInfo := make([]api.GroupMinimum, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
groupsChecked[group.ID] = struct{}{}
|
||||
for _, pk := range group.Peers {
|
||||
if pk == peerID {
|
||||
info := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
}
|
||||
groupsInfo = append(groupsInfo, info)
|
||||
break
|
||||
}
|
||||
}
|
||||
groupsInfo = append(groupsInfo, api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
})
|
||||
}
|
||||
return groupsInfo
|
||||
}
|
||||
|
||||
@@ -39,6 +39,68 @@ const (
|
||||
)
|
||||
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: "test_id",
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
Enabled: true,
|
||||
Action: "accept",
|
||||
Destinations: []string{"group1"},
|
||||
Sources: []string{"group1"},
|
||||
Bidirectional: true,
|
||||
Protocol: "all",
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: "test_id",
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: "test_id",
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
IP: net.ParseIP("100.67.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
Serial: 51,
|
||||
},
|
||||
}
|
||||
|
||||
return &PeersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
@@ -64,77 +126,37 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
GetUserPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
ListPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||
peersID := make([]string, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersID = append(peersID, peer.ID)
|
||||
}
|
||||
return []*nbgroup.Group{
|
||||
{
|
||||
ID: "group1",
|
||||
AccountID: accountID,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: peersID,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) {
|
||||
return account, nil
|
||||
},
|
||||
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer.Copy()
|
||||
}
|
||||
|
||||
policy := &server.Policy{
|
||||
ID: "policy",
|
||||
AccountID: accountID,
|
||||
Name: "policy",
|
||||
Enabled: true,
|
||||
Rules: []*server.PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
Name: "rule",
|
||||
Enabled: true,
|
||||
Action: "accept",
|
||||
Destinations: []string{"group1"},
|
||||
Sources: []string{"group1"},
|
||||
Bidirectional: true,
|
||||
Protocol: "all",
|
||||
Ports: []string{"80"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := server.NewRegularUser(serviceUser)
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &server.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
Peers: peersMap,
|
||||
Users: map[string]*server.User{
|
||||
adminUser: server.NewAdminUser(adminUser),
|
||||
regularUser: server.NewRegularUser(regularUser),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*nbgroup.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
AccountID: accountID,
|
||||
Name: "group1",
|
||||
Issued: "api",
|
||||
Peers: maps.Keys(peersMap),
|
||||
},
|
||||
},
|
||||
Settings: &server.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: time.Hour,
|
||||
},
|
||||
Policies: []*server.Policy{policy},
|
||||
Network: &server.Network{
|
||||
Identifier: "ciclqisab2ss43jdn8q0",
|
||||
Net: net.IPNet{
|
||||
IP: net.ParseIP("100.67.0.0"),
|
||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
||||
},
|
||||
Serial: 51,
|
||||
},
|
||||
}
|
||||
|
||||
return account, nil
|
||||
},
|
||||
HasConnectedChannelFunc: func(peerID string) bool {
|
||||
|
||||
@@ -130,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
|
||||
policy := server.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
|
||||
@@ -163,13 +163,16 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
||||
}
|
||||
}
|
||||
|
||||
isUpdate := postureChecksID != ""
|
||||
|
||||
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
postureChecks.AccountID = accountID
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, isUpdate); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks, _ bool) error {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -56,13 +58,15 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
|
||||
if len(groups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
|
||||
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
var found bool
|
||||
for _, accountGroup := range accountsGroups {
|
||||
for _, accountGroup := range accountGroups {
|
||||
if accountGroup.ID == group {
|
||||
found = true
|
||||
break
|
||||
@@ -76,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
|
||||
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||
for _, peer := range peers {
|
||||
peersMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
|
||||
}
|
||||
|
||||
@@ -461,7 +461,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 10 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
Timeout: 200 * time.Second,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -22,16 +22,17 @@ import (
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@@ -45,16 +46,16 @@ type MockAccountManager struct {
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -89,15 +90,15 @@ type MockAccountManager struct {
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
@@ -123,7 +124,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncAndMarkPeer is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
|
||||
@@ -131,7 +132,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
|
||||
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
|
||||
account, err := am.GetAccountFunc(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeers := make(map[string]struct{})
|
||||
for id := range account.Peers {
|
||||
approvedPeers[id] = struct{}{}
|
||||
@@ -171,16 +177,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
// GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountIDByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountIDByUser(
|
||||
ctx context.Context, userId, domain string,
|
||||
) (*server.Account, error) {
|
||||
if am.GetOrCreateAccountByUserFunc != nil {
|
||||
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
|
||||
) (string, error) {
|
||||
if am.GetOrCreateAccountIDByUserFunc != nil {
|
||||
return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
return "", status.Errorf(
|
||||
codes.Unimplemented,
|
||||
"method GetOrCreateAccountByUser is not implemented",
|
||||
"method GetOrCreateAccountIDByUser is not implemented",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -222,19 +228,19 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
|
||||
if am.GetAccountInfoFromPATFunc != nil {
|
||||
return am.GetAccountInfoFromPATFunc(ctx, token)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
@@ -354,14 +360,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
return am.ListGroupsFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
||||
}
|
||||
|
||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
if am.GroupAddPeerFunc != nil {
|
||||
@@ -626,12 +624,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl
|
||||
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
|
||||
}
|
||||
|
||||
// GetPeers mocks GetPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.GetPeersFunc != nil {
|
||||
return am.GetPeersFunc(ctx, accountID, userID)
|
||||
// GetUserPeers mocks GetUserPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.GetUserPeersFunc != nil {
|
||||
return am.GetUserPeersFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented")
|
||||
}
|
||||
|
||||
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface
|
||||
@@ -675,7 +673,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
|
||||
}
|
||||
@@ -691,9 +689,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(ctx, sync, account)
|
||||
return am.SyncPeerFunc(ctx, sync, accountID)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||
}
|
||||
@@ -739,9 +737,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
||||
}
|
||||
|
||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
|
||||
if am.SavePostureChecksFunc != nil {
|
||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, isUpdate)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||
}
|
||||
@@ -840,3 +838,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
|
||||
if am.GetPeerGroupsFunc != nil {
|
||||
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListPeers mocks ListPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.ListPeersFunc != nil {
|
||||
return am.ListPeersFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented")
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"slices"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -24,26 +26,31 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewUnauthorizedToViewNSGroupsError()
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
newNSGroup := &nbdns.NameServerGroup{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
NameServers: nameServerList,
|
||||
@@ -54,92 +61,136 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
SearchDomainsEnabled: searchDomainEnabled,
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(false, newNSGroup, account)
|
||||
err = am.validateNameServerGroup(ctx, accountID, newNSGroup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.NameServerGroups == nil {
|
||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
|
||||
}
|
||||
|
||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup); err != nil {
|
||||
return fmt.Errorf("failed to create nameserver group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
}
|
||||
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
oldNSGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nsGroupToSave.AccountID = accountID
|
||||
|
||||
if err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err := am.areNameServerGroupChangesAffectPeers(ctx, nsGroupToSave, oldNSGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave); err != nil {
|
||||
return fmt.Errorf("failed to update nameserver group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nsGroup := account.NameServerGroups[nsGroupID]
|
||||
if nsGroup == nil {
|
||||
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, nsGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
|
||||
return fmt.Errorf("failed to delete nameserver group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -150,39 +201,44 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewUnauthorizedToViewNSGroupsError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
||||
nsGroupID := ""
|
||||
if existingGroup {
|
||||
nsGroupID = nameserverGroup.ID
|
||||
_, found := account.NameServerGroups[nsGroupID]
|
||||
if !found {
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSList(nameserverGroup.NameServers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(nameserverGroup.Groups, account.Groups)
|
||||
nsServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(nameserverGroup.Groups, groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -190,6 +246,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
||||
return nil
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(ctx context.Context, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := am.anyGroupHasPeers(ctx, newNSGroup.AccountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return am.anyGroupHasPeers(ctx, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||
if !primary && len(domains) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
@@ -213,14 +287,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
|
||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||
}
|
||||
|
||||
for _, nsGroup := range nsGroupMap {
|
||||
for _, nsGroup := range groups {
|
||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,14 +302,14 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
||||
}
|
||||
|
||||
func validateNSList(list []nbdns.NameServer) error {
|
||||
nsListLenght := len(list)
|
||||
if nsListLenght == 0 || nsListLenght > 3 {
|
||||
nsListLength := len(list)
|
||||
if nsListLength == 0 || nsListLength > 3 {
|
||||
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
||||
func validateGroups(list []string, groups []*nbgroup.Group) error {
|
||||
if len(list) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
|
||||
}
|
||||
@@ -244,13 +318,8 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
||||
if id == "" {
|
||||
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||
}
|
||||
found := false
|
||||
for groupID := range groups {
|
||||
if id == groupID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
found := slices.ContainsFunc(groups, func(group *nbgroup.Group) bool { return group.ID == id })
|
||||
if !found {
|
||||
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||
}
|
||||
@@ -277,11 +346,3 @@ func validateDomain(domain string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false
|
||||
}
|
||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
outNSGroup, err := am.CreateNameServerGroup(
|
||||
context.Background(),
|
||||
account.Id,
|
||||
accountID,
|
||||
testCase.inputArgs.name,
|
||||
testCase.inputArgs.description,
|
||||
testCase.inputArgs.nameServers,
|
||||
@@ -408,7 +409,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
|
||||
// assign generated ID
|
||||
testCase.expectedNSGroup.ID = outNSGroup.ID
|
||||
|
||||
testCase.expectedNSGroup.AccountID = accountID
|
||||
if !testCase.expectedNSGroup.IsEqual(outNSGroup) {
|
||||
t.Errorf("new nameserver group didn't match expected ns group:\nGot %#v\nExpected:%#v\n", outNSGroup, testCase.expectedNSGroup)
|
||||
}
|
||||
@@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
testCase.existingNSGroup.AccountID = accountID
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup)
|
||||
require.NoError(t, err, "failed to save existing nameserver group")
|
||||
|
||||
var nsGroupToSave *nbdns.NameServerGroup
|
||||
|
||||
if !testCase.skipCopying {
|
||||
nsGroupToSave = testCase.existingNSGroup.Copy()
|
||||
|
||||
@@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave)
|
||||
|
||||
err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave)
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
if !testCase.shouldCreate {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
|
||||
require.True(t, saved)
|
||||
savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID)
|
||||
require.NoError(t, err, "failed to get saved nameserver group")
|
||||
|
||||
testCase.expectedNSGroup.AccountID = accountID
|
||||
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
|
||||
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
|
||||
}
|
||||
@@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
|
||||
testingNSGroup.AccountID = accountID
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup)
|
||||
require.NoError(t, err, "failed to save nameserver group")
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save account")
|
||||
}
|
||||
|
||||
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
|
||||
err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID)
|
||||
if err != nil {
|
||||
t.Error("deleting nameserver group failed with error: ", err)
|
||||
}
|
||||
|
||||
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Error("failed to retrieve saved account with error: ", err)
|
||||
}
|
||||
|
||||
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
|
||||
if found {
|
||||
t.Error("nameserver group shouldn't be found after delete")
|
||||
}
|
||||
_, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID)
|
||||
require.NotNil(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error should be a status error")
|
||||
assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete")
|
||||
}
|
||||
|
||||
func TestGetNameServerGroup(t *testing.T) {
|
||||
@@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
accountID, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
|
||||
foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID)
|
||||
if err != nil {
|
||||
t.Error("getting existing nameserver group failed with error: ", err)
|
||||
}
|
||||
@@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
t.Error("got a nil group while getting nameserver group with ID")
|
||||
}
|
||||
|
||||
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing")
|
||||
_, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing")
|
||||
if err == nil {
|
||||
t.Error("getting not existing nameserver group should return error, got nil")
|
||||
}
|
||||
@@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
|
||||
t.Helper()
|
||||
accountID := "testingAcc"
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
Key: nsGroupPeer1Key,
|
||||
Name: "test-host1@netbird.io",
|
||||
@@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
|
||||
}
|
||||
existingNSGroup := nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
AccountID: accountID,
|
||||
Name: existingNSGroupName,
|
||||
Description: "",
|
||||
NameServers: []nbdns.NameServer{
|
||||
@@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
accountID := "testingAcc"
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
newGroup1 := &nbgroup.Group{
|
||||
ID: group1ID,
|
||||
Name: group1ID,
|
||||
}
|
||||
|
||||
newGroup2 := &nbgroup.Group{
|
||||
ID: group2ID,
|
||||
Name: group2ID,
|
||||
}
|
||||
|
||||
account.Groups[newGroup1.ID] = newGroup1
|
||||
account.Groups[newGroup2.ID] = newGroup2
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{
|
||||
{
|
||||
ID: group1ID,
|
||||
AccountID: accountID,
|
||||
Name: group1ID,
|
||||
},
|
||||
{
|
||||
ID: group2ID,
|
||||
AccountID: accountID,
|
||||
Name: group2ID,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -50,34 +51,54 @@ type PeerLogin struct {
|
||||
ConnectionIP net.IP
|
||||
}
|
||||
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// ListPeers returns a list of peers under the given account.
|
||||
func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
regularUser := !user.HasAdminPower() && !user.IsServiceUser
|
||||
|
||||
if regularUser && account.Settings.RegularUsersViewBlocked {
|
||||
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if regularUser && user.Id != peer.UserID {
|
||||
accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, peer := range accountPeers {
|
||||
if user.IsRegularUser() && user.Id != peer.UserID {
|
||||
// only display peers that belong to the current user if the current user is not an admin
|
||||
continue
|
||||
}
|
||||
@@ -86,10 +107,15 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peersMap[peer.ID] = p
|
||||
}
|
||||
|
||||
if !regularUser {
|
||||
if user.IsAdminOrServiceUser() {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errGetAccountFmt, err)
|
||||
}
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
@@ -107,37 +133,42 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
if expired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
@@ -157,16 +188,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = am.Store.SavePeerLocation(account.Id, peer)
|
||||
err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -176,39 +205,50 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(update.ID)
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool
|
||||
|
||||
if peer.SSHEnabled != update.SSHEnabled {
|
||||
peer.SSHEnabled = update.SSHEnabled
|
||||
event := activity.PeerSSHEnabled
|
||||
if !update.SSHEnabled {
|
||||
event = activity.PeerSSHDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
sshChanged = true
|
||||
}
|
||||
|
||||
peerLabelUpdated := peer.Name != update.Name
|
||||
|
||||
if peerLabelUpdated {
|
||||
if peer.Name != update.Name {
|
||||
peer.Name = update.Name
|
||||
peerLabelChanged = true
|
||||
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
|
||||
if err != nil {
|
||||
@@ -216,134 +256,107 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
peer.DNSLabel = newLabel
|
||||
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
|
||||
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
|
||||
}
|
||||
|
||||
peer.LoginExpirationEnabled = update.LoginExpirationEnabled
|
||||
loginExpirationChanged = true
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated")
|
||||
}
|
||||
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
|
||||
inactivityExpirationChanged = true
|
||||
}
|
||||
|
||||
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sshChanged {
|
||||
event := activity.PeerSSHEnabled
|
||||
if !peer.SSHEnabled {
|
||||
event = activity.PeerSSHDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
if peerLabelChanged {
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
if loginExpirationChanged {
|
||||
event := activity.PeerLoginExpirationEnabled
|
||||
if !update.LoginExpirationEnabled {
|
||||
if !peer.LoginExpirationEnabled {
|
||||
event = activity.PeerLoginExpirationDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
|
||||
|
||||
if !peer.AddedWithSSOLogin() {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
|
||||
}
|
||||
|
||||
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
|
||||
|
||||
if inactivityExpirationChanged {
|
||||
event := activity.PeerInactivityExpirationEnabled
|
||||
if !update.InactivityExpirationEnabled {
|
||||
if !peer.InactivityExpirationEnabled {
|
||||
event = activity.PeerInactivityExpirationDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if peerLabelUpdated {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
|
||||
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
|
||||
|
||||
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
|
||||
// we might have some inconsistencies
|
||||
peers := make([]*nbpeer.Peer, 0, len(peerIDs))
|
||||
for _, peerID := range peerIDs {
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
|
||||
// the 2nd loop performs the actual modification
|
||||
for _, peer := range peers {
|
||||
|
||||
err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.DeletePeer(peer.ID)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID,
|
||||
&UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
// fill those field for backward compatibility
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
// new field
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: account.Network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers := isPeerInActiveGroup(account, peerID)
|
||||
if peerAccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
var peer *nbpeer.Peer
|
||||
var addPeerRemovedEvents []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peer to delete: %w", err)
|
||||
}
|
||||
|
||||
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
|
||||
addPeerRemovedEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -405,7 +418,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
addedByUser := false
|
||||
if len(userID) > 0 {
|
||||
addedByUser = true
|
||||
accountID, err = am.Store.GetAccountIDByUserID(userID)
|
||||
accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
|
||||
} else {
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
|
||||
}
|
||||
@@ -436,12 +449,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var groupsToAdd []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
var setupKeyID string
|
||||
var setupKeyName string
|
||||
var ephemeral bool
|
||||
var groupsToAdd []string
|
||||
if addedByUser {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
|
||||
if err != nil {
|
||||
@@ -550,7 +563,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@@ -584,30 +597,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
|
||||
}
|
||||
|
||||
groupsToAdd = append(groupsToAdd, allGroup.ID)
|
||||
if areGroupChangesAffectPeers(account, groupsToAdd) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||
@@ -630,14 +629,14 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
if peer.UserID != "" {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -648,48 +647,38 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
return nil, nil, nil, status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if sync.UpdateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
var postureChecks []*posture.Checks
|
||||
|
||||
if peerNotValid {
|
||||
emptyMap := &NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, postureChecks, nil
|
||||
}
|
||||
|
||||
if isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
validPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
@@ -764,7 +753,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -795,7 +784,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
|
||||
if shouldStorePeer {
|
||||
err = am.Store.SavePeer(ctx, accountID, peer)
|
||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -804,16 +793,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
unlockPeer()
|
||||
unlockPeer = nil
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
||||
}
|
||||
|
||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||
@@ -845,21 +829,33 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
var postureChecks []*posture.Checks
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
emptyMap := &NetworkMap{
|
||||
Network: account.Network.Copy(),
|
||||
Network: network.Copy(),
|
||||
}
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
@@ -873,7 +869,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||
peer = peer.UpdateLastLogin()
|
||||
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
|
||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -920,45 +916,51 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
|
||||
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if admin or user owns this peer, return peer
|
||||
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID {
|
||||
if user.IsAdminOrServiceUser() || peer.UserID == userID {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// it is also possible that user doesn't own the peer but some of his peers have access to it,
|
||||
// this is a valid case, show the peer as well.
|
||||
userPeers, err := account.FindUserPeers(userID)
|
||||
userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errGetAccountFmt, err)
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
for _, aclPeer := range aclPeers {
|
||||
@@ -973,7 +975,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
@@ -981,9 +983,15 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
}()
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
|
||||
return
|
||||
@@ -1007,7 +1015,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
@@ -1017,6 +1030,236 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are connected.
|
||||
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if len(peersWithExpiry) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var nextExpiry *time.Duration
|
||||
for _, peer := range peersWithExpiry {
|
||||
// consider only connected peers because others will require login on connecting to the management server
|
||||
if peer.Status.LoginExpired || !peer.Status.Connected {
|
||||
continue
|
||||
}
|
||||
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if nextExpiry == nil || duration < *nextExpiry {
|
||||
// if expiration is below 1s return 1s duration
|
||||
// this avoids issues with ticker that can't be set to < 0
|
||||
if duration < time.Second {
|
||||
return time.Second, true
|
||||
}
|
||||
nextExpiry = &duration
|
||||
}
|
||||
}
|
||||
|
||||
if nextExpiry == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if len(peersWithInactivity) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var nextExpiry *time.Duration
|
||||
for _, peer := range peersWithInactivity {
|
||||
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||
continue
|
||||
}
|
||||
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
|
||||
if nextExpiry == nil || duration < *nextExpiry {
|
||||
// if expiration is below 1s return 1s duration
|
||||
// this avoids issues with ticker that can't be set to < 0
|
||||
if duration < time.Second {
|
||||
return time.Second, true
|
||||
}
|
||||
nextExpiry = &duration
|
||||
}
|
||||
}
|
||||
|
||||
if nextExpiry == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// getExpiredPeers returns peers that have been expired.
|
||||
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// getInactivePeers returns peers that have been expired by inactivity
|
||||
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, inactivePeer := range peersWithInactivity {
|
||||
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
|
||||
if inactive {
|
||||
peers = append(peers, inactivePeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// GetPeerGroups returns groups that the peer is part of.
|
||||
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerGroups := make([]*nbgroup.Group, 0)
|
||||
for _, group := range groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroups = append(peerGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return peerGroups, nil
|
||||
}
|
||||
|
||||
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
||||
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
|
||||
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
groupIDs = append(groupIDs, group.ID)
|
||||
}
|
||||
|
||||
return groupIDs, err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
|
||||
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingLabels := make(lookupMap)
|
||||
for _, label := range dnsLabels {
|
||||
existingLabels[label] = struct{}{}
|
||||
}
|
||||
return existingLabels, nil
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
|
||||
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return am.areGroupChangesAffectPeers(ctx, accountID, peerGroupIDs)
|
||||
}
|
||||
|
||||
// deletePeers deletes all specified peers and sends updates to the remote peers.
|
||||
// Returns a slice of functions to save events after successful peer deletion.
|
||||
func deletePeers(ctx context.Context, am *DefaultAccountManager, store Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
|
||||
var peerDeletedEvents []func()
|
||||
|
||||
for _, peer := range peers {
|
||||
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil {
|
||||
return nil, fmt.Errorf("failed to validate peer: %w", err)
|
||||
}
|
||||
|
||||
network, err := store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account network: %w", err)
|
||||
}
|
||||
|
||||
if err = store.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil {
|
||||
return nil, fmt.Errorf("failed to delete peer: %w", err)
|
||||
}
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
peerDeletedEvents = append(peerDeletedEvents, func() {
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
})
|
||||
}
|
||||
|
||||
return peerDeletedEvents, nil
|
||||
}
|
||||
|
||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||
for _, label := range existingLabels {
|
||||
@@ -1024,15 +1267,3 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
}
|
||||
return labelMap
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(account *Account, peerID string) bool {
|
||||
peerGroupIDs := make([]string, 0)
|
||||
for _, group := range account.Groups {
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
peerGroupIDs = append(peerGroupIDs, group.ID)
|
||||
}
|
||||
}
|
||||
return areGroupChangesAffectPeers(account, peerGroupIDs)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ type Peer struct {
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool
|
||||
Ephemeral bool `gorm:"index"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
}
|
||||
|
||||
@@ -467,21 +467,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
Id: someUser,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
account.Settings.RegularUsersViewBlocked = false
|
||||
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: someUser,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleUser,
|
||||
})
|
||||
require.NoError(t, err, "failed to create user")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.RegularUsersViewBlocked = false
|
||||
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
|
||||
require.NoError(t, err, "failed to save account settings")
|
||||
|
||||
// two peers one added by a regular user and one with a setup key
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@@ -535,7 +539,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
assert.NotNil(t, peer)
|
||||
|
||||
// delete the all-to-all policy so that user's peer1 has no access to peer2
|
||||
for _, policy := range account.Policies {
|
||||
accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account policies")
|
||||
|
||||
for _, policy := range accountPolicies {
|
||||
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -563,7 +570,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
assert.NotNil(t, peer)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
role UserRole
|
||||
@@ -654,21 +661,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
|
||||
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: someUser,
|
||||
AccountID: accountID,
|
||||
Role: testCase.role,
|
||||
IsServiceUser: testCase.isServiceUser,
|
||||
}
|
||||
account.Policies = []*Policy{}
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
})
|
||||
require.NoError(t, err, "failed to create user")
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account policies")
|
||||
|
||||
for _, policy := range accountPolicies {
|
||||
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
|
||||
require.NoError(t, err, "failed to delete policy")
|
||||
}
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
|
||||
require.NoError(t, err, "failed to save account settings")
|
||||
|
||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -699,7 +718,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := manager.GetPeers(context.Background(), accountID, someUser)
|
||||
peers, err := manager.GetUserPeers(context.Background(), accountID, someUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -724,10 +743,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &User{
|
||||
Id: regularUser,
|
||||
Role: UserRoleUser,
|
||||
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: regularUser,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleUser,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
// Create peers
|
||||
@@ -741,31 +768,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
UserID: regularUser,
|
||||
}
|
||||
account.Peers[peer.ID] = peer
|
||||
err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
account.Policies = make([]*Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &nbgroup.Group{
|
||||
ID: groupID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
ID: groupID,
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
}
|
||||
for j := 0; j < peers/groups; j++ {
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
account.Groups[groupID] = group
|
||||
|
||||
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &Policy{
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: fmt.Sprintf("rule-%d", i),
|
||||
PolicyID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||
Enabled: true,
|
||||
Sources: []string{groupID},
|
||||
@@ -776,22 +812,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
|
||||
},
|
||||
},
|
||||
}
|
||||
account.Policies = append(account.Policies, policy)
|
||||
|
||||
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{
|
||||
ID: "PostureChecksAll",
|
||||
AccountID: accountID,
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
@@ -824,9 +861,9 @@ func BenchmarkGetPeers(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := manager.GetPeers(context.Background(), accountID, userID)
|
||||
_, err := manager.GetUserPeers(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetPeers failed: %v", err)
|
||||
b.Fatalf("GetUserPeers failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -876,7 +913,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
manager.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
@@ -1401,10 +1438,13 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
ID: "policy",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule",
|
||||
PolicyID: "policy",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
|
||||
@@ -8,9 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -58,7 +57,7 @@ type PersonalAccessTokenGenerated struct {
|
||||
|
||||
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
func CreateNewPAT(name string, expirationInDays int, targetUserID, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateNewToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -67,6 +66,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
|
||||
return &PersonalAccessTokenGenerated{
|
||||
PersonalAccessToken: PersonalAccessToken{
|
||||
ID: xid.New().String(),
|
||||
UserID: targetUserID,
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
|
||||
|
||||
@@ -3,13 +3,13 @@ package server
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"slices"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -171,6 +171,7 @@ type Policy struct {
|
||||
func (p *Policy) Copy() *Policy {
|
||||
c := &Policy{
|
||||
ID: p.ID,
|
||||
AccountID: p.AccountID,
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
@@ -211,7 +212,6 @@ func (p *Policy) ruleGroups() []string {
|
||||
groups = append(groups, rule.Sources...)
|
||||
groups = append(groups, rule.Destinations...)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
@@ -343,30 +343,73 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for index, rule := range policy.Rules {
|
||||
rule.Sources = getValidGroupIDs(groups, rule.Sources)
|
||||
rule.Destinations = getValidGroupIDs(groups, rule.Destinations)
|
||||
policy.Rules[index] = rule
|
||||
}
|
||||
|
||||
if policy.SourcePostureChecks != nil {
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
saveFunc := transaction.SavePolicy
|
||||
if !isUpdate {
|
||||
saveFunc = transaction.CreatePolicy
|
||||
}
|
||||
|
||||
if err := saveFunc(ctx, LockingStrengthUpdate, policy); err != nil {
|
||||
return fmt.Errorf("failed to save policy: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -377,7 +420,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -385,115 +428,91 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
policy, err := am.deletePolicy(account, policyID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID); err != nil {
|
||||
return fmt.Errorf("failed to delete policy: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPolicies from the store
|
||||
// ListPolicies from the store.
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
||||
policyIdx := -1
|
||||
for i, policy := range account.Policies {
|
||||
if policy.ID == policyID {
|
||||
policyIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if policyIdx < 0 {
|
||||
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
|
||||
}
|
||||
|
||||
policy := account.Policies[policyIdx]
|
||||
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// savePolicy saves or updates a policy in the given account.
|
||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
||||
for index, rule := range policyToSave.Rules {
|
||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||
policyToSave.Rules[index] = rule
|
||||
}
|
||||
|
||||
if policyToSave.SourcePostureChecks != nil {
|
||||
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func (am *DefaultAccountManager) arePolicyChangesAffectPeers(ctx context.Context, policy *Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||
if policyIdx < 0 {
|
||||
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||
existingPolicy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policy.AccountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
oldPolicy := account.Policies[policyIdx]
|
||||
// Update the existing policy
|
||||
account.Policies[policyIdx] = policyToSave
|
||||
|
||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
||||
|
||||
return updateAccountPeers, nil
|
||||
}
|
||||
|
||||
// Add the new policy to the account
|
||||
account.Policies = append(account.Policies, policyToSave)
|
||||
|
||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
||||
}
|
||||
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
hasPeers, err := am.anyGroupHasPeers(ctx, policy.AccountID, existingPolicy.ruleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return am.anyGroupHasPeers(ctx, policy.AccountID, policy.ruleGroups())
|
||||
}
|
||||
return result
|
||||
|
||||
return am.anyGroupHasPeers(ctx, policy.AccountID, policy.ruleGroups())
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
@@ -574,27 +593,52 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||
// that are valid within the provided account.
|
||||
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||
result := make([]string, 0, len(postureChecksIds))
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
|
||||
validPostureCheckIDs := make(map[string]struct{})
|
||||
for _, check := range postureChecks {
|
||||
validPostureCheckIDs[check.ID] = struct{}{}
|
||||
}
|
||||
|
||||
validIDs := make([]string, 0, len(postureChecksIds))
|
||||
for _, id := range postureChecksIds {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if id == postureCheck.ID {
|
||||
result = append(result, id)
|
||||
continue
|
||||
}
|
||||
if _, exists := validPostureCheckIDs[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||
result := make([]string, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
if _, exists := account.Groups[groupID]; exists {
|
||||
result = append(result, groupID)
|
||||
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
|
||||
validGroupIDs := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
validGroupIDs[group.ID] = struct{}{}
|
||||
}
|
||||
|
||||
validIDs := make([]string, 0, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if _, exists := validGroupIDs[id]; exists {
|
||||
validIDs = append(validIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
return validIDs
|
||||
}
|
||||
|
||||
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||
result := make([]*proto.FirewallRule, len(rules))
|
||||
for i := range rules {
|
||||
rule := rules[i]
|
||||
|
||||
result[i] = &proto.FirewallRule{
|
||||
PeerIP: rule.PeerIP,
|
||||
Direction: getProtoDirection(rule.Direction),
|
||||
Action: getProtoAction(rule.Action),
|
||||
Protocol: getProtoProtocol(rule.Protocol),
|
||||
Port: rule.Port,
|
||||
}
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -832,24 +832,28 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
ID: "groupC",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupD",
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
ID: "groupD",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupD",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@@ -862,11 +866,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-rule-groups-no-peers",
|
||||
Enabled: true,
|
||||
ID: "policy-rule-groups-no-peers",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-rule-groups-no-peers",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
@@ -896,11 +902,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-has-peers-destination-none",
|
||||
Enabled: true,
|
||||
ID: "policy-source-has-peers-destination-none",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-source-has-peers-destination-none",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
@@ -931,11 +939,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// update account's peers and send peer update
|
||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-destination-has-peers-source-none",
|
||||
Enabled: true,
|
||||
ID: "policy-destination-has-peers-source-none",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-destination-has-peers-source-none",
|
||||
Enabled: false,
|
||||
Sources: []string{"groupC"},
|
||||
Destinations: []string{"groupD"},
|
||||
@@ -966,11 +976,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// and send peer update
|
||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
ID: "policy-source-destination-peers",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
@@ -1000,11 +1012,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// and send peer update
|
||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: false,
|
||||
ID: "policy-source-destination-peers",
|
||||
AccountID: account.Id,
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
@@ -1035,11 +1049,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
AccountID: account.Id,
|
||||
Description: "updated description",
|
||||
Enabled: false,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -1069,11 +1085,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
// and send peer update
|
||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||
policy := Policy{
|
||||
ID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
ID: "policy-source-destination-peers",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: "policy-source-destination-peers",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupD"},
|
||||
|
||||
@@ -2,16 +2,14 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
@@ -20,85 +18,127 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
||||
|
||||
// we do not allow create new posture checks with non uniq name
|
||||
if !exists && !uniqName {
|
||||
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
|
||||
updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action := activity.PostureCheckCreated
|
||||
if exists {
|
||||
action = activity.PostureCheckUpdated
|
||||
account.Network.IncSerial()
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if isUpdate {
|
||||
action = activity.PostureCheckUpdated
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
return fmt.Errorf("failed to get posture checks: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks); err != nil {
|
||||
return fmt.Errorf("failed to save posture checks: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
func (am *DefaultAccountManager) validatePostureChecks(ctx context.Context, accountID string, postureChecks *posture.Checks) error {
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
checks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
for _, check := range checks {
|
||||
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
|
||||
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePostureChecks deletes a posture check by ID.
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
||||
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
|
||||
return fmt.Errorf("failed to delete posture checks: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -107,132 +147,123 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPostureChecks returns a list of posture checks.
|
||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
||||
uniqName = true
|
||||
for i, p := range account.PostureChecks {
|
||||
if !exists && p.ID == postureChecks.ID {
|
||||
account.PostureChecks[i] = postureChecks
|
||||
exists = true
|
||||
}
|
||||
if p.Name == postureChecks.Name {
|
||||
uniqName = false
|
||||
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||
func (am *DefaultAccountManager) isPostureCheckLinkedToPolicy(ctx context.Context, postureChecksID, accountID string) error {
|
||||
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
account.PostureChecks = append(account.PostureChecks, postureChecks)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
|
||||
postureChecksIdx := -1
|
||||
for i, postureChecks := range account.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
postureChecksIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if postureChecksIdx < 0 {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
||||
}
|
||||
|
||||
// Check if posture check is linked to any policy
|
||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
||||
}
|
||||
|
||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
||||
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
|
||||
|
||||
return postureChecks, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
|
||||
peerPostureChecks := make(map[string]posture.Checks)
|
||||
|
||||
if len(account.PostureChecks) == 0 {
|
||||
return nil
|
||||
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
|
||||
postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil || len(postureChecks) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerPostureChecks := make(map[string]*posture.Checks)
|
||||
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
|
||||
addPolicyPostureChecks(account, policy, peerPostureChecks)
|
||||
isInGroup, err := am.isPeerInPolicySourceGroups(ctx, accountID, peerID, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isInGroup {
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
postureCheck, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
|
||||
if err == nil {
|
||||
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
|
||||
for _, check := range peerPostureChecks {
|
||||
checkCopy := check
|
||||
postureChecksList = append(postureChecksList, &checkCopy)
|
||||
}
|
||||
|
||||
return postureChecksList
|
||||
return maps.Values(peerPostureChecks), nil
|
||||
}
|
||||
|
||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
|
||||
func (am *DefaultAccountManager) isPeerInPolicySourceGroups(ctx context.Context, accountID, peerID string, policy *Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, ok := account.Groups[sourceGroup]
|
||||
if ok && slices.Contains(group.Peers, peerID) {
|
||||
return true
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
|
||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||
for _, postureCheck := range account.PostureChecks {
|
||||
if postureCheck.ID == sourcePostureCheckID {
|
||||
peerPostureChecks[sourcePostureCheckID] = *postureCheck
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
||||
for _, policy := range account.Policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func (am *DefaultAccountManager) arePostureCheckChangesAffectPeers(ctx context.Context, accountID, postureCheckID string, exists bool) (bool, error) {
|
||||
if !exists {
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
|
||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
||||
if !isLinked {
|
||||
return false
|
||||
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := am.anyGroupHasPeers(ctx, accountID, policy.ruleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
|
||||
@@ -26,41 +27,43 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestPostureChecksAccount(am)
|
||||
accountID, err := initTestPostureChecksAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
t.Run("Generic posture check flow", func(t *testing.T) {
|
||||
// regular users can not create checks
|
||||
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
||||
err := am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
// regular users cannot list check
|
||||
_, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID)
|
||||
_, err = am.ListPostureChecks(context.Background(), accountID, regularUserID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// should be possible to create posture check with uniq name
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
Name: postureCheckName,
|
||||
err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
AccountID: accountID,
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.26.0",
|
||||
},
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// admin users can list check
|
||||
checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||
checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, checks, 1)
|
||||
|
||||
// should not be possible to create posture check with non uniq name
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: "new-id",
|
||||
Name: postureCheckName,
|
||||
err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
|
||||
ID: "new-id",
|
||||
AccountID: accountID,
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||
Locations: []posture.Location{
|
||||
@@ -70,57 +73,61 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
assert.Error(t, err)
|
||||
|
||||
// admins can update posture checks
|
||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
Name: postureCheckName,
|
||||
err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
|
||||
ID: postureCheckID,
|
||||
AccountID: accountID,
|
||||
Name: postureCheckName,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.27.0",
|
||||
},
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// users should not be able to delete posture checks
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
|
||||
err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, regularUserID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// admin should be able to delete posture checks
|
||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
|
||||
err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||
checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, checks, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) {
|
||||
accountID := "testingAccount"
|
||||
domain := "example.com"
|
||||
|
||||
admin := &User{
|
||||
Id: adminUserID,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
user := &User{
|
||||
Id: regularUserID,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
|
||||
{
|
||||
Id: adminUserID,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleAdmin,
|
||||
},
|
||||
{
|
||||
Id: regularUserID,
|
||||
AccountID: accountID,
|
||||
Role: UserRoleUser,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
@@ -128,19 +135,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
},
|
||||
{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
ID: "groupB",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
ID: "groupC",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupC",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
@@ -169,7 +179,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -192,7 +202,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -203,11 +213,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
})
|
||||
|
||||
policy := Policy{
|
||||
ID: "policyA",
|
||||
Enabled: true,
|
||||
ID: "policyA",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
ID: "ruleA",
|
||||
PolicyID: "policyA",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -255,7 +267,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -303,17 +315,19 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
ID: "policyB",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
ID: "ruleB",
|
||||
PolicyID: "policyB",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupC"},
|
||||
@@ -337,7 +351,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -355,11 +369,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||
})
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
ID: "policyB",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
ID: "ruleB",
|
||||
PolicyID: "policyB",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupB"},
|
||||
Destinations: []string{"groupA"},
|
||||
@@ -384,7 +400,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
MinVersion: "0.29.0",
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -398,10 +414,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
// should trigger account peers update and send peer update
|
||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||
policy = Policy{
|
||||
ID: "policyB",
|
||||
Enabled: true,
|
||||
ID: "policyB",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "ruleB",
|
||||
PolicyID: "policyB",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupB"},
|
||||
@@ -429,7 +448,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -441,79 +460,126 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
||||
account := &Account{
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policyA",
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"checkA"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*group.Group{
|
||||
"groupA": {
|
||||
ID: "groupA",
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
"groupB": {
|
||||
ID: "groupB",
|
||||
Peers: []string{},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: "checkA",
|
||||
},
|
||||
{
|
||||
ID: "checkB",
|
||||
},
|
||||
},
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
accountID, err := initTestPostureChecksAccount(manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
groupA := &group.Group{
|
||||
ID: "groupA",
|
||||
AccountID: accountID,
|
||||
Peers: []string{"peer1"},
|
||||
}
|
||||
|
||||
groupB := &group.Group{
|
||||
ID: "groupB",
|
||||
AccountID: accountID,
|
||||
Peers: []string{},
|
||||
}
|
||||
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
policy := &Policy{
|
||||
ID: "policyA",
|
||||
AccountID: accountID,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "ruleA",
|
||||
PolicyID: "policyA",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"groupA"},
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{"checkA"},
|
||||
}
|
||||
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err, "failed to save policy")
|
||||
|
||||
postureCheckA := &posture.Checks{
|
||||
ID: "checkA",
|
||||
Name: "checkA",
|
||||
AccountID: accountID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||
},
|
||||
}
|
||||
err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA, false)
|
||||
require.NoError(t, err, "failed to save postureCheckA")
|
||||
|
||||
postureCheckB := &posture.Checks{
|
||||
ID: "checkB",
|
||||
Name: "checkB",
|
||||
AccountID: accountID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||
},
|
||||
}
|
||||
err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB, false)
|
||||
require.NoError(t, err, "failed to save postureCheckB")
|
||||
|
||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkB", true)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check does not exist", func(t *testing.T) {
|
||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "unknown", false)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
policy.Rules[0].Sources = []string{"groupB"}
|
||||
policy.Rules[0].Destinations = []string{"groupA"}
|
||||
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
policy.Rules[0].Sources = []string{"groupA"}
|
||||
policy.Rules[0].Destinations = []string{"groupB"}
|
||||
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
groupA.Peers = []string{}
|
||||
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
|
||||
require.NoError(t, err, "failed to save groups")
|
||||
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result)
|
||||
})
|
||||
|
||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||
account.Groups["groupA"].Peers = []string{}
|
||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||
policy.Rules[0].Sources = []string{"nonExistentGroup"}
|
||||
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err, "failed to update policy")
|
||||
|
||||
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, sErr.Type())
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -52,17 +52,43 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(ctx context.Context, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, r := range accountRoutes {
|
||||
dynamic := r.IsDynamic()
|
||||
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||
!dynamic && r.Network.String() == prefix.String() {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, accountID, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(ctx, accountID, prefix, domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@@ -81,8 +107,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
for _, groupID := range prefixRoute.PeerGroups {
|
||||
seenPeerGroups[groupID] = true
|
||||
|
||||
group := account.GetGroup(groupID)
|
||||
if group == nil {
|
||||
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil || group == nil {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||
getRouteDescriptor(prefix, domains), groupID,
|
||||
@@ -97,10 +123,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
|
||||
if peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
if err != nil || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeers[peerID]; ok {
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||
@@ -109,7 +136,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
|
||||
// check that peerGroupIDs are not in any route peerGroups list
|
||||
for _, groupID := range peerGroupIDs {
|
||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
||||
// we validated the group existence before entering this function, no need to check again.
|
||||
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
if err != nil || group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeerGroups[groupID]; ok {
|
||||
return status.Errorf(
|
||||
@@ -120,10 +151,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
for _, id := range group.Peers {
|
||||
if _, ok := seenPeers[id]; ok {
|
||||
peer := account.GetPeer(id)
|
||||
if peer == nil {
|
||||
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||
@@ -143,16 +175,22 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peerID != "" {
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
@@ -179,22 +217,28 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
var newRoute route.Route
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
newRoute.AccountID = accountID
|
||||
|
||||
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
err = validateGroups(peerGroupIDs, accountGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessControlGroupIDs) > 0 {
|
||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||
err = validateGroups(accessControlGroupIDs, accountGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -207,7 +251,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
err = validateGroups(groups, account.Groups)
|
||||
err = validateGroups(groups, accountGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -226,30 +270,46 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
newRoute.KeepRoute = keepRoute
|
||||
newRoute.AccessControlGroups = accessControlGroupIDs
|
||||
|
||||
if account.Routes == nil {
|
||||
account.Routes = make(map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
account.Routes[newRoute.ID] = &newRoute
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, &newRoute)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, &newRoute)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return &newRoute, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if routeToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
@@ -263,18 +323,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
oldRoute, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeToSave.ID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
@@ -291,72 +344,119 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if routeToSave.Peer != "" {
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
err = validateGroups(routeToSave.PeerGroups, groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.AccessControlGroups) > 0 {
|
||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||
err = validateGroups(routeToSave.AccessControlGroups, groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
err = validateGroups(routeToSave.Groups, groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute := account.Routes[routeToSave.ID]
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
oldRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, oldRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
newRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routy := account.Routes[routeID]
|
||||
if routy == nil {
|
||||
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
delete(account.Routes, routeID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isRouteChangeAffectPeers(account, routy) {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID)); err != nil {
|
||||
return fmt.Errorf("failed to delete route: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -369,8 +469,12 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
|
||||
@@ -649,8 +753,21 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func (am *DefaultAccountManager) areRouteChangesAffectPeers(ctx context.Context, route *route.Route) (bool, error) {
|
||||
if route.Peer != "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasPeers, err := am.anyGroupHasPeers(ctx, route.AccountID, route.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return am.anyGroupHasPeers(ctx, route.AccountID, route.PeerGroups)
|
||||
}
|
||||
|
||||
@@ -5,19 +5,20 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -427,21 +428,22 @@ func TestCreateRoute(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Errorf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
if testCase.createInitRoute {
|
||||
groupAll, errInit := account.GetGroupAll()
|
||||
groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
|
||||
|
||||
_, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
|
||||
_, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
}
|
||||
|
||||
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||
outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
@@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) {
|
||||
|
||||
// assign generated ID
|
||||
testCase.expectedRoute.ID = outRoute.ID
|
||||
testCase.expectedRoute.AccountID = accountID
|
||||
|
||||
if !testCase.expectedRoute.IsEqual(outRoute) {
|
||||
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute)
|
||||
@@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
if testCase.createInitRoute {
|
||||
account.Routes["initRoute"] = &route.Route{
|
||||
initRoute := &route.Route{
|
||||
ID: "initRoute",
|
||||
AccountID: accountID,
|
||||
Network: existingNetwork,
|
||||
NetID: existingRouteID,
|
||||
NetworkType: route.IPv4Network,
|
||||
@@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) {
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute)
|
||||
require.NoError(t, err, "failed to save init route")
|
||||
}
|
||||
|
||||
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
testCase.existingRoute.AccountID = accountID
|
||||
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute)
|
||||
require.NoError(t, err, "failed to save existing route")
|
||||
|
||||
var routeToSave *route.Route
|
||||
|
||||
@@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave)
|
||||
err = am.SaveRoute(context.Background(), accountID, userID, routeToSave)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
@@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
|
||||
require.True(t, saved)
|
||||
savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID)
|
||||
require.NoError(t, err, "failed to get saved route")
|
||||
|
||||
testCase.expectedRoute.AccountID = accountID
|
||||
if !testCase.expectedRoute.IsEqual(savedRoute) {
|
||||
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
|
||||
}
|
||||
@@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteRoute(t *testing.T) {
|
||||
testingRoute := &route.Route{
|
||||
ID: "testingRoute",
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
Domains: domain.List{"domain1", "domain2"},
|
||||
KeepRoute: true,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1Key,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
am, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.Routes[testingRoute.ID] = testingRoute
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "GroupA",
|
||||
AccountID: accountID,
|
||||
Name: "GroupA",
|
||||
})
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save account")
|
||||
testingRoute := &route.Route{
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: route.NetID("12345678901234567890qw"),
|
||||
Groups: []string{"GroupA"},
|
||||
KeepRoute: true,
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute)
|
||||
require.NoError(t, err, "failed to create route")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID)
|
||||
err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID)
|
||||
if err != nil {
|
||||
t.Error("deleting route failed with error: ", err)
|
||||
}
|
||||
|
||||
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Error("failed to retrieve saved account with error: ", err)
|
||||
}
|
||||
|
||||
_, found := savedAccount.Routes[testingRoute.ID]
|
||||
if found {
|
||||
t.Error("route shouldn't be found after delete")
|
||||
}
|
||||
_, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID)
|
||||
require.NotNil(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, sErr.Type())
|
||||
}
|
||||
|
||||
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
@@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||
newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newRoute.Enabled, true)
|
||||
|
||||
@@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
|
||||
|
||||
groups, err := am.ListGroups(context.Background(), account.Id)
|
||||
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
var groupHA1, groupHA2 *nbgroup.Group
|
||||
for _, group := range groups {
|
||||
@@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID)
|
||||
err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
|
||||
|
||||
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
|
||||
err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
|
||||
|
||||
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID)
|
||||
err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1158,16 +1153,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
accountID, err := initTestRouteAccount(t, am)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
|
||||
createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1181,7 +1174,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
expectedRoute := enabledRoute.Copy()
|
||||
expectedRoute.Peer = peer1Key
|
||||
|
||||
err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute)
|
||||
err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1193,7 +1186,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group")
|
||||
|
||||
err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID)
|
||||
err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
|
||||
@@ -1202,27 +1195,29 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group")
|
||||
|
||||
newGroup := &nbgroup.Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "peer1 group",
|
||||
Peers: []string{peer1ID},
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: "peer1 group",
|
||||
Peers: []string{peer1ID},
|
||||
}
|
||||
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup)
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, newGroup)
|
||||
require.NoError(t, err)
|
||||
|
||||
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
|
||||
policies, err := am.ListPolicies(context.Background(), accountID, "testingUser")
|
||||
require.NoError(t, err)
|
||||
|
||||
defaultRule := rules[0]
|
||||
defaultRule := policies[0]
|
||||
newPolicy := defaultRule.Copy()
|
||||
newPolicy.ID = xid.New().String()
|
||||
|
||||
newPolicy.Name = "peer1 only"
|
||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||
|
||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
||||
err = am.SavePolicy(context.Background(), accountID, userID, newPolicy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||
err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1233,7 +1228,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2")
|
||||
|
||||
err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID)
|
||||
err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -1267,179 +1262,104 @@ func createRouterStore(t *testing.T) (Store, error) {
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
|
||||
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
ips := account.getTakenIPs()
|
||||
peer1IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) {
|
||||
ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerIP, err := AllocatePeerIP(network.Net, ips)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
IP: peerIP,
|
||||
AccountID: accountID,
|
||||
ID: peerID,
|
||||
Key: peerKey,
|
||||
Name: peerName,
|
||||
DNSLabel: dnsLabel,
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: peerName,
|
||||
GoOS: strings.ToLower(kernel),
|
||||
Kernel: kernel,
|
||||
Core: core,
|
||||
Platform: platform,
|
||||
OS: os,
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// Create peers
|
||||
peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer1 := &nbpeer.Peer{
|
||||
IP: peer1IP,
|
||||
ID: peer1ID,
|
||||
Key: peer1Key,
|
||||
Name: "test-host1@netbird.io",
|
||||
DNSLabel: "test-host1",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host1@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer1.ID] = peer1
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer2 := &nbpeer.Peer{
|
||||
IP: peer2IP,
|
||||
ID: peer2ID,
|
||||
Key: peer2Key,
|
||||
Name: "test-host2@netbird.io",
|
||||
DNSLabel: "test-host2",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host2@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer2.ID] = peer2
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer3 := &nbpeer.Peer{
|
||||
IP: peer3IP,
|
||||
ID: peer3ID,
|
||||
Key: peer3Key,
|
||||
Name: "test-host3@netbird.io",
|
||||
DNSLabel: "test-host3",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host3@netbird.io",
|
||||
GoOS: "darwin",
|
||||
Kernel: "Darwin",
|
||||
Core: "13.4.1",
|
||||
Platform: "arm64",
|
||||
OS: "darwin",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer3.ID] = peer3
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer4 := &nbpeer.Peer{
|
||||
IP: peer4IP,
|
||||
ID: peer4ID,
|
||||
Key: peer4Key,
|
||||
Name: "test-host4@netbird.io",
|
||||
DNSLabel: "test-host4",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host4@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
}
|
||||
account.Peers[peer4.ID] = peer4
|
||||
|
||||
ips = account.getTakenIPs()
|
||||
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
|
||||
peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
peer5 := &nbpeer.Peer{
|
||||
IP: peer5IP,
|
||||
ID: peer5ID,
|
||||
Key: peer5Key,
|
||||
Name: "test-host5@netbird.io",
|
||||
DNSLabel: "test-host5",
|
||||
UserID: userID,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-host5@netbird.io",
|
||||
GoOS: "linux",
|
||||
Kernel: "Linux",
|
||||
Core: "21.04",
|
||||
Platform: "x86_64",
|
||||
OS: "Ubuntu",
|
||||
WtVersion: "development",
|
||||
UIVersion: "development",
|
||||
},
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
groupAll, err := am.GetGroupByName(context.Background(), "All", accountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
account.Peers[peer5.ID] = peer5
|
||||
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupAll, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
newGroup := []*nbgroup.Group{
|
||||
newGroups := []*nbgroup.Group{
|
||||
{
|
||||
ID: routeGroup1,
|
||||
Name: routeGroup1,
|
||||
@@ -1471,15 +1391,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
||||
Peers: []string{peer1.ID, peer4.ID},
|
||||
},
|
||||
}
|
||||
|
||||
for _, group := range newGroup {
|
||||
err = am.SaveGroup(context.Background(), accountID, userID, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.SaveGroups(context.Background(), accountID, userID, newGroups)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
@@ -1783,10 +1700,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
manager, err := createRouterManager(t)
|
||||
require.NoError(t, err, "failed to create account manager")
|
||||
|
||||
account, err := initTestRouteAccount(t, manager)
|
||||
accountID, err := initTestRouteAccount(t, manager)
|
||||
require.NoError(t, err, "failed to init testing account")
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
@@ -1832,7 +1749,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
@@ -1868,7 +1785,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
@@ -1904,7 +1821,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
newRoute, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
|
||||
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
|
||||
)
|
||||
@@ -1928,7 +1845,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
|
||||
err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1946,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
|
||||
err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
@@ -1970,7 +1887,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
@@ -1982,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "groupB",
|
||||
Name: "GroupB",
|
||||
Peers: []string{peer1ID},
|
||||
@@ -2010,7 +1927,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
Groups: []string{"groupC"},
|
||||
}
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
@@ -2022,7 +1939,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
close(done)
|
||||
}()
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "groupC",
|
||||
Name: "GroupC",
|
||||
Peers: []string{peer1ID},
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -226,34 +226,44 @@ func Hash(s string) uint32 {
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = validateSetupKeyAutoGroups(groups, autoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed adding account key")
|
||||
setupKey.AccountID = accountID
|
||||
|
||||
if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
|
||||
groupMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, g := range groups {
|
||||
groupMap[g.ID] = g
|
||||
}
|
||||
|
||||
for _, g := range setupKey.AutoGroups {
|
||||
group := account.GetGroup(g)
|
||||
group := groupMap[g]
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,30 +278,30 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Id == keyToSave.Id {
|
||||
oldKey = key.Copy()
|
||||
break
|
||||
}
|
||||
}
|
||||
if oldKey == nil {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = validateSetupKeyAutoGroups(groups, keyToSave.AutoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -302,9 +312,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
|
||||
account.SetupKeys[newKey.Key] = newKey
|
||||
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, newKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -315,24 +323,30 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
defer func() {
|
||||
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
|
||||
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
|
||||
|
||||
groupMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, g := range groups {
|
||||
groupMap[g.ID] = g
|
||||
}
|
||||
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
group := groupMap[g]
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
group := groupMap[g]
|
||||
if group != nil {
|
||||
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -347,8 +361,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
|
||||
@@ -366,8 +384,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return nil, status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
@@ -387,21 +409,25 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
||||
return status.NewUnauthorizedToViewSetupKeysError()
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get setup key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
|
||||
err = am.Store.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete setup key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
|
||||
@@ -409,15 +435,22 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
for _, group := range autoGroups {
|
||||
g, ok := account.Groups[group]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error {
|
||||
groupMap := make(map[string]*nbgroup.Group, len(groups))
|
||||
for _, g := range groups {
|
||||
groupMap[g.ID] = g
|
||||
}
|
||||
|
||||
for _, groupID := range autoGroups {
|
||||
g, exists := groupMap[groupID]
|
||||
if !exists {
|
||||
return status.Errorf(status.NotFound, "group %s doesn't exist", groupID)
|
||||
}
|
||||
|
||||
if g.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
|
||||
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,21 +25,21 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err, "failed to get or create account ID")
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
ID: "group_1",
|
||||
AccountID: accountID,
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
ID: "group_2",
|
||||
AccountID: accountID,
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
expiresIn := time.Hour
|
||||
keyName := "my-test-key"
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{},
|
||||
key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{},
|
||||
SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
newKeyName := "my-new-test-key"
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
@@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
key.Id, time.Now().UTC(), autoGroups, true)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked)
|
||||
ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked)
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, accountID, ev.AccountID)
|
||||
assert.Equal(t, newKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
assert.Equal(t, key.Id, ev.TargetID)
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
assert.NoError(t, err)
|
||||
groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// saving setup key with All group assigned to auto groups should return error
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
_, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
@@ -103,31 +103,31 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err, "failed to get or create account ID")
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
AccountID: accountID,
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
AccountID: accountID,
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
assert.NoError(t, err)
|
||||
groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
@@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
@@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
tCase.expectedUpdatedAt, tCase.expectedGroups, false)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated)
|
||||
ev := getEvent(t, accountID, manager, activity.SetupKeyCreated)
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, accountID, ev.AccountID)
|
||||
assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
@@ -208,12 +208,10 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err, "failed to get or create account ID")
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
@@ -222,7 +220,7 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
@@ -384,20 +382,24 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
ID: "groupA",
|
||||
AccountID: account.Id,
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
policy := Policy{
|
||||
ID: "policy",
|
||||
Enabled: true,
|
||||
ID: "policy",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "Rule",
|
||||
PolicyID: "policy",
|
||||
Enabled: true,
|
||||
Sources: []string{"groupA"},
|
||||
Destinations: []string{"group"},
|
||||
Destinations: []string{"groupA"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -68,20 +68,27 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||
func runLargeTest(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
groupALL, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID := "account_id"
|
||||
|
||||
err := newAccountWithId(context.Background(), store, accountID, "testuser", "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
|
||||
assert.NoError(t, err, "failed to get group All")
|
||||
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
assert.NoError(t, err, "failed to save setup key")
|
||||
|
||||
const numPerAccount = 6000
|
||||
for n := 0; n < numPerAccount; n++ {
|
||||
netIP := randomIPv4()
|
||||
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
|
||||
peerID := fmt.Sprintf("%s-peer-%d", accountID, n)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
AccountID: accountID,
|
||||
Key: peerID,
|
||||
IP: netIP,
|
||||
Name: peerID,
|
||||
@@ -90,16 +97,21 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
}
|
||||
account.Peers[peerID] = peer
|
||||
group, _ := account.GetGroupAll()
|
||||
group.Peers = append(group.Peers, peerID)
|
||||
user := &User{
|
||||
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Users[user.Id] = user
|
||||
err = store.AddPeerToAccount(context.Background(), peer)
|
||||
assert.NoError(t, err, "failed to add peer")
|
||||
|
||||
err = store.AddPeerToAllGroup(context.Background(), accountID, peerID)
|
||||
assert.NoError(t, err, "failed to add peer to all group")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: fmt.Sprintf("%s-user-%d", accountID, n),
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err, "failed to save user")
|
||||
|
||||
route := &route2.Route{
|
||||
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
|
||||
AccountID: accountID,
|
||||
Description: "base route",
|
||||
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
|
||||
Network: netip.MustParsePrefix(netIP.String() + "/24"),
|
||||
@@ -107,22 +119,24 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
Groups: []string{groupALL.ID},
|
||||
Groups: []string{groupAll.ID},
|
||||
}
|
||||
account.Routes[route.ID] = route
|
||||
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
|
||||
assert.NoError(t, err, "failed to save route")
|
||||
|
||||
group = &nbgroup.Group{
|
||||
group := &nbgroup.Group{
|
||||
ID: fmt.Sprintf("group-id-%d", n),
|
||||
AccountID: account.Id,
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("group-id-%d", n),
|
||||
Issued: "api",
|
||||
Peers: nil,
|
||||
}
|
||||
account.Groups[group.ID] = group
|
||||
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||
assert.NoError(t, err, "failed to save group")
|
||||
|
||||
nameserver := &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("nameserver-id-%d", n),
|
||||
AccountID: account.Id,
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("nameserver-id-%d", n),
|
||||
Description: "",
|
||||
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
|
||||
@@ -132,20 +146,20 @@ func runLargeTest(t *testing.T, store Store) {
|
||||
Enabled: false,
|
||||
SearchDomainsEnabled: false,
|
||||
}
|
||||
account.NameServerGroups[nameserver.ID] = nameserver
|
||||
err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver)
|
||||
assert.NoError(t, err, "failed to save nameserver group")
|
||||
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
assert.NoError(t, err, "failed to save setup key")
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 1 {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(context.Background(), account.Id)
|
||||
a, err := store.GetAccount(context.Background(), accountID)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
@@ -213,41 +227,53 @@ func TestSqlite_SaveAccount(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "account_id"
|
||||
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
require.NoError(t, err, "failed to save setup key")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
|
||||
ID: "testpeer",
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
AccountID: accountID,
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
})
|
||||
require.NoError(t, err, "failed to save peer")
|
||||
|
||||
accountID2 := "account_id2"
|
||||
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
setupKey.AccountID = accountID2
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
require.NoError(t, err, "failed to save setup key")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account2)
|
||||
require.NoError(t, err)
|
||||
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
|
||||
ID: "testpeer2",
|
||||
Key: "peerkey2",
|
||||
AccountID: accountID2,
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
})
|
||||
require.NoError(t, err, "failed to save peer")
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 2 {
|
||||
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(context.Background(), account.Id)
|
||||
a, err := store.GetAccount(context.Background(), accountID)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
@@ -288,36 +314,52 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accountID := "account_id"
|
||||
testUserID := "testuser"
|
||||
user := NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
}}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 1 {
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
require.NoError(t, err, "failed to save setup key")
|
||||
|
||||
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
|
||||
ID: "testpeer",
|
||||
Key: "peerkey",
|
||||
AccountID: accountID,
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
})
|
||||
require.NoError(t, err, "failed to save peer")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: "testtoken",
|
||||
UserID: testUserID,
|
||||
Name: "test token",
|
||||
})
|
||||
require.NoError(t, err, "failed to save personal access token")
|
||||
|
||||
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
|
||||
require.NoError(t, err, "failed to get all account ids")
|
||||
|
||||
if len(accountIDs) != 1 {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
err = store.DeleteAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
accountIDs, err = store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
|
||||
require.NoError(t, err, "failed to get all account ids after DeleteAccount()")
|
||||
|
||||
if len(accountIDs) != 0 {
|
||||
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
|
||||
}
|
||||
|
||||
@@ -400,7 +442,7 @@ func TestSqlite_SavePeer(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
ctx := context.Background()
|
||||
err = store.SavePeer(ctx, account.Id, peer)
|
||||
err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@@ -416,7 +458,7 @@ func TestSqlite_SavePeer(t *testing.T) {
|
||||
updatedPeer.Status.Connected = false
|
||||
updatedPeer.Meta.Hostname = "updatedpeer"
|
||||
|
||||
err = store.SavePeer(ctx, account.Id, updatedPeer)
|
||||
err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
@@ -442,7 +484,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@@ -461,7 +503,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
|
||||
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
@@ -472,7 +514,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
|
||||
newStatus.Connected = true
|
||||
|
||||
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
|
||||
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
@@ -507,7 +549,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
}
|
||||
// error is expected as peer is not in store yet
|
||||
err = store.SavePeerLocation(account.Id, peer)
|
||||
err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
@@ -519,7 +561,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||
peer.Location.CityName = "Berlin"
|
||||
peer.Location.GeoNameID = 2950159
|
||||
|
||||
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
|
||||
err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
@@ -529,7 +571,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
|
||||
assert.Equal(t, peer.Location, actual)
|
||||
|
||||
peer.ID = "non-existing-peer"
|
||||
err = store.SavePeerLocation(account.Id, peer)
|
||||
err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@@ -572,11 +614,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
||||
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
require.Equal(t, id, pat.ID)
|
||||
|
||||
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
|
||||
_, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@@ -595,11 +637,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
|
||||
|
||||
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
|
||||
_, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@@ -714,19 +756,28 @@ func newSqliteStore(t *testing.T) *SqlStore {
|
||||
}
|
||||
|
||||
func newAccount(store Store, id int) error {
|
||||
str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
|
||||
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com")
|
||||
accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id)
|
||||
userID := accountID + "-testuser"
|
||||
|
||||
err := newAccountWithId(context.Background(), store, accountID, userID, "example.com")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["p"+str] = &nbpeer.Peer{
|
||||
Key: "peerkey" + str,
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
|
||||
Key: accountID + "-peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
return store.SaveAccount(context.Background(), account)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgresql_NewStore(t *testing.T) {
|
||||
@@ -754,39 +805,56 @@ func TestPostgresql_SaveAccount(t *testing.T) {
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
accountID := "account_id"
|
||||
|
||||
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
require.NoError(t, err, "failed to save setup key")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
|
||||
ID: "testpeer",
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
AccountID: accountID,
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
})
|
||||
require.NoError(t, err, "failed to save peer")
|
||||
|
||||
accountID2 := "account_id2"
|
||||
|
||||
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey, _ = GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
setupKey.AccountID = accountID2
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
require.NoError(t, err, "failed to save setup key")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account2)
|
||||
require.NoError(t, err)
|
||||
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
|
||||
ID: "testpeer2",
|
||||
Key: "peerkey2",
|
||||
AccountID: accountID2,
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
})
|
||||
require.NoError(t, err, "failed to save peer")
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 2 {
|
||||
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
|
||||
require.NoError(t, err, "failed to get all account ids")
|
||||
|
||||
if len(accountIDs) != 2 {
|
||||
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(context.Background(), account.Id)
|
||||
a, err := store.GetAccount(context.Background(), accountID)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
@@ -827,32 +895,51 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accountID := "account_id"
|
||||
testUserID := "testuser"
|
||||
|
||||
user := NewAdminUser(testUserID)
|
||||
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
|
||||
ID: "testtoken",
|
||||
Name: "test token",
|
||||
}}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
|
||||
err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
setupKey, _ := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
account.Users[testUserID] = user
|
||||
setupKey.AccountID = accountID
|
||||
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
|
||||
require.NoError(t, err, "failed to save setup key")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
|
||||
ID: "testingpeer",
|
||||
AccountID: accountID,
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
})
|
||||
require.NoError(t, err, "failed to save peer")
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 1 {
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: "testtoken",
|
||||
UserID: testUserID,
|
||||
Name: "test token",
|
||||
})
|
||||
require.NoError(t, err, "failed to save personal access token")
|
||||
|
||||
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
|
||||
require.NoError(t, err, "failed to get all account ids")
|
||||
|
||||
if len(accountIDs) != 1 {
|
||||
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
err = store.DeleteAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -908,7 +995,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
|
||||
// save new status of existing peer
|
||||
@@ -924,7 +1011,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
|
||||
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
@@ -967,9 +1054,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
||||
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
require.Equal(t, id, pat.ID)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
@@ -984,7 +1071,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
@@ -1047,7 +1134,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
_, err = store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
labels, err := store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, labels)
|
||||
|
||||
@@ -1059,7 +1146,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
err = store.AddPeerToAccount(context.Background(), peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test"}, labels)
|
||||
|
||||
@@ -1071,7 +1158,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
|
||||
err = store.AddPeerToAccount(context.Background(), peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
|
||||
}
|
||||
@@ -1181,7 +1268,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
t.Fatal("failed to save group")
|
||||
return err
|
||||
}
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get group")
|
||||
return err
|
||||
@@ -1201,7 +1288,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
users, err := store.GetAccountUsers(context.Background(), accountID)
|
||||
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, len(account.Users))
|
||||
}
|
||||
@@ -1218,7 +1305,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
domain := "example.com"
|
||||
category := "public"
|
||||
IsDomainPrimaryAccount := false
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
@@ -1232,7 +1319,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
|
||||
require.NoError(t, err)
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err)
|
||||
@@ -1246,7 +1333,9 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
|
||||
domain := "test.com"
|
||||
category := "private"
|
||||
IsDomainPrimaryAccount := true
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
|
||||
err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id",
|
||||
domain, category, &IsDomainPrimaryAccount,
|
||||
)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
@@ -1274,7 +1363,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID)
|
||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
|
||||
@@ -1290,6 +1379,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
nonExistingKeyID := "non-existing-key-id"
|
||||
|
||||
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID)
|
||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ func NewPeerNotFoundError(peerKey string) error {
|
||||
}
|
||||
|
||||
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account
|
||||
func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
func NewAccountNotFoundError() error {
|
||||
return Errorf(NotFound, "account not found")
|
||||
}
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
@@ -102,23 +102,38 @@ func NewPeerLoginExpiredError() error {
|
||||
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError(err error) error {
|
||||
return Errorf(NotFound, "setup key not found: %s", err)
|
||||
}
|
||||
|
||||
func NewGetAccountFromStoreError(err error) error {
|
||||
return Errorf(Internal, "issue getting account from store: %s", err)
|
||||
}
|
||||
|
||||
func NewUnauthorizedToViewAccountSettingError() error {
|
||||
return Errorf(PermissionDenied, "only users with admin power can view account settings")
|
||||
}
|
||||
|
||||
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
|
||||
func NewUserNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "user is not part of this account")
|
||||
}
|
||||
|
||||
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
|
||||
func NewGetUserFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
|
||||
func NewAdminPermissionError() error {
|
||||
return Errorf(PermissionDenied, "admin role required to perform this action")
|
||||
}
|
||||
|
||||
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
|
||||
// to delete a user with the owner role.
|
||||
func NewOwnerDeletePermissionError() error {
|
||||
return Errorf(PermissionDenied, "can't delete a user with the owner role")
|
||||
}
|
||||
|
||||
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
|
||||
func NewServiceUserRoleInvalidError() error {
|
||||
return Errorf(InvalidArgument, "can't create a service user with owner role")
|
||||
}
|
||||
|
||||
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
|
||||
@@ -126,7 +141,24 @@ func NewInvalidKeyIDError() error {
|
||||
return Errorf(InvalidArgument, "invalid key ID")
|
||||
}
|
||||
|
||||
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
|
||||
func NewUnauthorizedToViewSetupKeysError() error {
|
||||
return Errorf(Unauthorized, "only users with admin power can view setup keys")
|
||||
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
|
||||
func NewSetupKeyNotFoundError(err error) error {
|
||||
return Errorf(NotFound, "setup key not found: %s", err)
|
||||
}
|
||||
|
||||
func NewPATNotFoundError() error {
|
||||
return Errorf(NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
func NewGetPATFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting pat from store")
|
||||
}
|
||||
|
||||
func NewUnauthorizedToViewNSGroupsError() error {
|
||||
return Errorf(PermissionDenied, "only users with admin power can view name server groups")
|
||||
}
|
||||
|
||||
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
|
||||
func NewStoreContextCanceledError(duration time.Duration) error {
|
||||
return Errorf(Internal, "store access: context canceled after %v", duration)
|
||||
}
|
||||
|
||||
@@ -47,65 +47,95 @@ type Store interface {
|
||||
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
|
||||
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
|
||||
GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error)
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(userID string) (string, error)
|
||||
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
|
||||
CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
||||
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
|
||||
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
|
||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
|
||||
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
|
||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
@@ -124,7 +154,6 @@ type Store interface {
|
||||
// This is also a method of metrics.DataSource interface.
|
||||
GetStoreEngine() StoreEngine
|
||||
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
|
||||
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
|
||||
}
|
||||
|
||||
type StoreEngine string
|
||||
|
||||
@@ -32,4 +32,7 @@ INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-3465300
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
|
||||
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
|
||||
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -43,37 +43,34 @@ const (
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
assert.Equal(t, newPAT.CreatedBy, mockUserID)
|
||||
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
|
||||
pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting token ID by hashed token: %s", err)
|
||||
}
|
||||
|
||||
if tokenID == "" {
|
||||
if pat.ID == "" {
|
||||
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.ID, tokenID)
|
||||
assert.Equal(t, newPAT.ID, pat.ID)
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||
}
|
||||
@@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockTargetUserId,
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockTargetUserId] = &User{
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockTargetUserId,
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: true,
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
|
||||
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
|
||||
assert.Errorf(t, err, "Wrong expiration should thorw error")
|
||||
assert.Errorf(t, err, "Wrong expiration should throw error")
|
||||
}
|
||||
|
||||
func TestUser_DeletePAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID1,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken1,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), mockAccountID)
|
||||
account, err := store.GetAccount(context.Background(), mockAccountID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
@@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
func TestUser_GetPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID1,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken1,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
func TestUser_GetAllPATs(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockUserID] = &User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
mockTokenID1: {
|
||||
ID: mockTokenID1,
|
||||
HashedToken: mockToken1,
|
||||
},
|
||||
mockTokenID2: {
|
||||
ID: mockTokenID2,
|
||||
HashedToken: mockToken2,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID1,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken1,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
|
||||
ID: mockTokenID2,
|
||||
UserID: mockUserID,
|
||||
HashedToken: mockToken2,
|
||||
})
|
||||
assert.NoError(t, err, "failed to create PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) {
|
||||
func TestUser_CreateServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when creating service user: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), mockAccountID)
|
||||
account, err := store.GetAccount(context.Background(), mockAccountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(account.Users))
|
||||
@@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
t.Fatalf("Error when creating user: %s", err)
|
||||
}
|
||||
|
||||
account, err = store.GetAccount(context.Background(), mockAccountID)
|
||||
account, err := store.GetAccount(context.Background(), mockAccountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, user.IsServiceUser)
|
||||
@@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
func TestUser_InviteNewUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -549,13 +519,12 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser)
|
||||
assert.NoError(t, err, "failed to create service user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -582,12 +551,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -603,39 +569,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
|
||||
{
|
||||
Id: "user2",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
},
|
||||
{
|
||||
Id: "user3",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user4",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
},
|
||||
{
|
||||
Id: "user5",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err, "failed to save users")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -685,61 +650,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
}
|
||||
account.Users["user6"] = &User{
|
||||
Id: "user6",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user7"] = &User{
|
||||
Id: "user7",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user8"] = &User{
|
||||
Id: "user8",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
account.Users["user9"] = &User{
|
||||
Id: "user9",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
|
||||
{
|
||||
Id: "user2",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
},
|
||||
{
|
||||
Id: "user3",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user4",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
},
|
||||
{
|
||||
Id: "user5",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
},
|
||||
{
|
||||
Id: "user6",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user7",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
},
|
||||
{
|
||||
Id: "user8",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
},
|
||||
{
|
||||
Id: "user9",
|
||||
AccountID: mockAccountID,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -816,7 +784,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
acc, err := am.Store.GetAccount(context.Background(), mockAccountID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, id := range tc.expectedDeleted {
|
||||
@@ -836,12 +804,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -865,14 +830,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = NewRegularUser("normal_user2")
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
newUser := NewRegularUser("normal_user1")
|
||||
newUser.AccountID = mockAccountID
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
newUser = NewRegularUser("normal_user2")
|
||||
newUser.AccountID = mockAccountID
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -946,15 +916,24 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
delete(account.Users, mockUserID)
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID)
|
||||
assert.NoError(t, err, "failed to get account settings")
|
||||
|
||||
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
|
||||
err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings)
|
||||
assert.NoError(t, err, "failed to save account settings")
|
||||
|
||||
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID)
|
||||
assert.NoError(t, err, "failed to delete user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -968,7 +947,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 1, len(users))
|
||||
|
||||
userInfo, _ := users[0].ToUserInfo(nil, account.Settings)
|
||||
userInfo, _ := users[0].ToUserInfo(nil, settings)
|
||||
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
|
||||
})
|
||||
}
|
||||
@@ -978,22 +957,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
|
||||
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
externalUser := &User{
|
||||
Id: "externalUser",
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedIntegration,
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: "externalUser",
|
||||
AccountID: mockAccountID,
|
||||
Role: UserRoleUser,
|
||||
Issued: UserIssuedIntegration,
|
||||
IntegrationReference: integration_reference.IntegrationReference{
|
||||
ID: 1,
|
||||
IntegrationType: "external",
|
||||
},
|
||||
}
|
||||
account.Users[externalUser.Id] = externalUser
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -1013,6 +991,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
cacheManager := am.GetExternalCacheManager()
|
||||
|
||||
externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser")
|
||||
assert.NoError(t, err, "failed to get user")
|
||||
|
||||
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
|
||||
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"})
|
||||
assert.NoError(t, err)
|
||||
@@ -1042,17 +1024,17 @@ func TestUser_IsAdmin(t *testing.T) {
|
||||
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockServiceUserID,
|
||||
AccountID: mockAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -1071,17 +1053,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
account.Users[mockServiceUserID] = &User{
|
||||
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
|
||||
Id: mockServiceUserID,
|
||||
AccountID: mockAccountID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
})
|
||||
assert.NoError(t, err, "failed to create user")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -1240,21 +1221,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
// create an account and an admin user
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create other users
|
||||
account.Users[regularUserID] = NewRegularUser(regularUserID)
|
||||
account.Users[adminUserID] = NewAdminUser(adminUserID)
|
||||
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
regularUser := NewRegularUser(regularUserID)
|
||||
regularUser.AccountID = accountID
|
||||
|
||||
adminUser := NewAdminUser(adminUserID)
|
||||
adminUser.AccountID = accountID
|
||||
|
||||
serviceUser := &User{
|
||||
Id: serviceUserID,
|
||||
AccountID: accountID,
|
||||
IsServiceUser: true,
|
||||
Role: UserRoleAdmin,
|
||||
ServiceUserName: "service",
|
||||
}
|
||||
|
||||
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update)
|
||||
err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser})
|
||||
assert.NoError(t, err, "failed to save users")
|
||||
|
||||
updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update)
|
||||
if tc.expectedErr {
|
||||
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||
} else {
|
||||
|
||||
@@ -88,18 +88,18 @@ type Route struct {
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `gorm:"index"`
|
||||
// Network and Domains are mutually exclusive
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
KeepRoute bool
|
||||
NetID NetID
|
||||
Description string
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
NetworkType NetworkType
|
||||
Masquerade bool
|
||||
Metric int
|
||||
Enabled bool
|
||||
Groups []string `gorm:"serializer:json"`
|
||||
Network netip.Prefix `gorm:"serializer:json"`
|
||||
Domains domain.List `gorm:"serializer:json"`
|
||||
KeepRoute bool
|
||||
NetID NetID
|
||||
Description string
|
||||
Peer string
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
NetworkType NetworkType
|
||||
Masquerade bool
|
||||
Metric int
|
||||
Enabled bool
|
||||
Groups []string `gorm:"serializer:json"`
|
||||
AccessControlGroups []string `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
@@ -111,19 +111,20 @@ func (r *Route) EventMeta() map[string]any {
|
||||
// Copy copies a route object
|
||||
func (r *Route) Copy() *Route {
|
||||
route := &Route{
|
||||
ID: r.ID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
Network: r.Network,
|
||||
Domains: slices.Clone(r.Domains),
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: r.NetworkType,
|
||||
Peer: r.Peer,
|
||||
PeerGroups: slices.Clone(r.PeerGroups),
|
||||
Metric: r.Metric,
|
||||
Masquerade: r.Masquerade,
|
||||
Enabled: r.Enabled,
|
||||
Groups: slices.Clone(r.Groups),
|
||||
ID: r.ID,
|
||||
AccountID: r.AccountID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
Network: r.Network,
|
||||
Domains: slices.Clone(r.Domains),
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: r.NetworkType,
|
||||
Peer: r.Peer,
|
||||
PeerGroups: slices.Clone(r.PeerGroups),
|
||||
Metric: r.Metric,
|
||||
Masquerade: r.Masquerade,
|
||||
Enabled: r.Enabled,
|
||||
Groups: slices.Clone(r.Groups),
|
||||
AccessControlGroups: slices.Clone(r.AccessControlGroups),
|
||||
}
|
||||
return route
|
||||
@@ -138,6 +139,7 @@ func (r *Route) IsEqual(other *Route) bool {
|
||||
}
|
||||
|
||||
return other.ID == r.ID &&
|
||||
other.AccountID == r.AccountID &&
|
||||
other.Description == r.Description &&
|
||||
other.NetID == r.NetID &&
|
||||
other.Network == r.Network &&
|
||||
@@ -149,7 +151,7 @@ func (r *Route) IsEqual(other *Route) bool {
|
||||
other.Masquerade == r.Masquerade &&
|
||||
other.Enabled == r.Enabled &&
|
||||
slices.Equal(r.Groups, other.Groups) &&
|
||||
slices.Equal(r.PeerGroups, other.PeerGroups)&&
|
||||
slices.Equal(r.PeerGroups, other.PeerGroups) &&
|
||||
slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user