Compare commits

...

78 Commits

Author SHA1 Message Date
bcmmbaga
a23a09bba3 Fix failed to create policy and delete user PAT on postgres
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-07 18:34:07 +03:00
bcmmbaga
2f7027194b Remove code duplicate on peer
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-07 12:56:54 +03:00
bcmmbaga
197d844a16 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-07 11:39:24 +03:00
bcmmbaga
df6c9a528a Refactor UpdatePeer method to defer event logging and scheduling until after peer save
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-06 16:01:07 +03:00
bcmmbaga
9cb7336ef5 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-06 15:59:12 +03:00
bcmmbaga
e513e51e9f Handle new account creation directly within the store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-04 14:43:14 +03:00
bcmmbaga
4ad00e784c Remove redundant accounts All group check on startup
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-03 18:49:16 +03:00
bcmmbaga
bfeb7f0875 Refactor users updating
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-03 01:14:26 +03:00
bcmmbaga
dde01b8e02 Refactor user and peers delete
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-02 16:05:12 +03:00
bcmmbaga
74246d18ba Merge branch 'main' into refactor/get-account-usage
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-02 06:41:34 +03:00
bcmmbaga
fa5db7d7ee Refactor service user handling, user cache lookup, and cache loading
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-01 20:27:52 +03:00
bcmmbaga
fed48de83f Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-01 14:02:09 +03:00
bcmmbaga
e73b5da42b Refactor update account peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-31 22:30:13 +03:00
bcmmbaga
8cacdae70c Merge branch 'main' into refactor/get-account-usage
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-31 21:59:09 +03:00
bcmmbaga
6b94f6e4e7 Refactor ephemeral peers and mark PAT as used
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-31 21:50:05 +03:00
bcmmbaga
b7525d9fe8 Merge branch 'main' into refactor/get-account-usage 2024-10-30 22:36:47 +03:00
bcmmbaga
901d283114 Merge branch 'main' into refactor-get-account-by-token
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-30 22:34:59 +03:00
bcmmbaga
7278a21b0d refactor get account in peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-29 13:50:44 +03:00
bcmmbaga
9bf0bf4843 wip: refactor get account in peers
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-28 17:47:54 +03:00
bcmmbaga
313e158e20 Refactor route
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-25 13:18:24 +03:00
bcmmbaga
0bdcb41e20 Refactor peer expiry, inactivity, location and status update to remove get account
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-23 19:03:48 +03:00
bcmmbaga
97dbdd7940 fix group tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-18 10:48:28 +03:00
bcmmbaga
a82b5ce80e Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/account.go
2024-10-17 22:01:26 +03:00
bcmmbaga
83be99c849 refactor get peers posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 21:58:34 +03:00
bcmmbaga
ee96a81b83 fix handler tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 16:34:44 +03:00
bcmmbaga
b0edc5f1f7 Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/sql_store.go
2024-10-17 16:10:16 +03:00
bcmmbaga
408d0cd504 Refactor policy save and delete
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 14:11:22 +03:00
bcmmbaga
b66f331711 get the first element when get record by ID
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 14:10:01 +03:00
bcmmbaga
d7a6996bed check user accounts for setup keys
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-17 11:59:46 +03:00
bcmmbaga
d7c63d5c04 Remove get account from groups ops
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-16 16:04:34 +03:00
bcmmbaga
1123729c1c fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-15 18:17:47 +03:00
bcmmbaga
a8c8b77df8 Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/account.go
#	management/server/file_store.go
#	management/server/peer.go
#	management/server/policy.go
#	management/server/route.go
#	management/server/sql_store.go
#	management/server/store.go
#	management/server/user.go
2024-10-14 14:31:55 +03:00
bcmmbaga
0297b5f142 wip: refactoring
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-02 11:56:47 +03:00
bcmmbaga
78e238646c refactor groups methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 16:32:31 +03:00
bcmmbaga
f9ed25f8b1 wip refactor peer methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 01:07:48 +03:00
bcmmbaga
f43a006c34 Fix posture check name uniqueness per account
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 01:06:52 +03:00
bcmmbaga
1a37b12d1b refactor user PAT
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 00:55:32 +03:00
bcmmbaga
d36d30dec4 refactor name server groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 00:54:53 +03:00
bcmmbaga
43eb7261e3 refactor account and dns settings
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-01 00:54:28 +03:00
bcmmbaga
9e47c94a7f refactor setup keys
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-30 14:02:55 +03:00
bcmmbaga
edf67672ad fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-27 22:31:26 +03:00
bcmmbaga
bc520412ba Merge branch 'main' into refactor/get-account-usage
# Conflicts:
#	management/server/file_store.go
#	management/server/http/posture_checks_handler.go
#	management/server/mock_server/account_mock.go
#	management/server/policy.go
#	management/server/sql_store.go
#	management/server/store.go
2024-09-27 20:27:05 +03:00
bcmmbaga
d87fe0257b Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage 2024-09-26 19:48:17 +03:00
bcmmbaga
b1b2b0adf0 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 19:47:43 +03:00
bcmmbaga
96f18c2c8c fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 19:46:37 +03:00
bcmmbaga
73be8c8a32 fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 19:00:59 +03:00
bcmmbaga
f61c914fd7 Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage
# Conflicts:
#	management/server/file_store.go
2024-09-26 18:51:47 +03:00
bcmmbaga
4575ae2841 add store lock
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 18:46:23 +03:00
bcmmbaga
ca6a9fd602 Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage 2024-09-26 16:39:52 +03:00
bcmmbaga
871595d15f Merge branch 'main' into refactor-get-account-by-token
# Conflicts:
#	management/server/sql_store.go
2024-09-26 16:39:17 +03:00
bcmmbaga
30253b0565 Merge branch 'refactor-get-account-by-token' into refactor/get-account-usage 2024-09-26 16:34:36 +03:00
bcmmbaga
dc82c2d1ce fix add missing policy source posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 16:34:19 +03:00
bcmmbaga
3b4bcdf5a4 refactor posture checks save and deletion
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-26 16:28:49 +03:00
bcmmbaga
87c8430e99 add store policy save and method
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 22:47:54 +03:00
bcmmbaga
c384874d7d fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 15:04:57 +03:00
bcmmbaga
b815393180 fix lint
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 13:02:08 +03:00
bcmmbaga
41b212f610 Refactor store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 12:53:20 +03:00
bcmmbaga
16174f0478 Refactor route, setupkey, nameserver and dns to get record(s) from store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-25 12:52:42 +03:00
bcmmbaga
d14b855670 Refactor user permissions and retrieves PAT
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 22:57:04 +03:00
bcmmbaga
eab85644cd Refactor retrieval of policy and posture checks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 21:57:33 +03:00
bcmmbaga
7561706627 add GetGroupByID from store and refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 19:55:33 +03:00
bcmmbaga
1ffe89d20d add GetGroupByName from store
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 16:36:57 +03:00
bcmmbaga
28840383e1 refactor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-24 13:30:13 +03:00
bcmmbaga
d9f612d623 remove locks
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-23 20:12:57 +03:00
bcmmbaga
7601a17150 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-22 23:44:10 +03:00
bcmmbaga
8f98adddf6 refactor handlers to use GetAccountIDFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-22 15:14:31 +03:00
bcmmbaga
26dd045da5 Merge branch 'main' into refactor-get-account-by-token 2024-09-20 14:08:09 +03:00
bcmmbaga
4d9bb7ea35 refactor getAccountWithAuthorizationClaims to return account id
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-20 14:07:44 +03:00
bcmmbaga
9631cb4fb3 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 19:05:21 +03:00
bcmmbaga
8f9c54f6c2 remove GetUserByID from account manager
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 17:03:04 +03:00
bcmmbaga
f60a4234b1 revert handles change
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 16:40:47 +03:00
bcmmbaga
021fc8f33e fix merge
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 16:11:20 +03:00
bcmmbaga
a4c4158bcf Merge branch 'main' into refactor-get-account-by-token 2024-09-18 16:03:55 +03:00
bcmmbaga
720d36a290 refactor getAccountWithAuthorizationClaims
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 15:55:52 +03:00
bcmmbaga
ccab3b427f refactor getAccountFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-18 14:24:39 +03:00
bcmmbaga
e5d55d3c10 refactor handlers to get account when necessary
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-17 23:15:54 +03:00
bcmmbaga
3cf1b02f31 refactor jwt groups extractor
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-17 18:18:59 +03:00
bcmmbaga
258b30cf48 refactor access control middleware and user access by JWT groups
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-16 13:33:36 +03:00
49 changed files with 4701 additions and 3535 deletions

View File

@@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
@@ -171,7 +170,7 @@ type Engine struct {
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
srWatcher *guard.SRWatcher
}
// Peer is an instance of the Connection Peer

View File

@@ -28,9 +28,12 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun
CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0);
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,'');
COMMIT;

View File

@@ -136,6 +136,7 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
func (g *NameServerGroup) Copy() *NameServerGroup {
nsGroup := &NameServerGroup{
ID: g.ID,
AccountID: g.AccountID,
Name: g.Name,
Description: g.Description,
NameServers: make([]NameServer, len(g.NameServers)),
@@ -156,6 +157,7 @@ func (g *NameServerGroup) Copy() *NameServerGroup {
// IsEqual compares one nameserver group with the other
func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID &&
other.AccountID == g.AccountID &&
other.Name == g.Name &&
other.Description == g.Description &&
other.Primary == g.Primary &&

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -397,7 +398,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
store := newStore(t)
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), "account-1")
require.NoError(t, err, "failed to get account")
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@@ -415,6 +423,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
store.Close(context.Background())
}
}
@@ -422,7 +432,15 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
store := newStore(t)
defer store.Close(context.Background())
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
require.NoError(t, err, "failed to create account")
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@@ -433,16 +451,16 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
if account == nil {
if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userID)
return
}
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
account, err := manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return
@@ -665,15 +683,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
@@ -689,44 +704,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "only ALL group should exists")
require.Len(t, accountGroups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "failed to update account settings")
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get account ids")
require.Len(t, accountIDs, 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account groups")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
})
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
require.NoError(t, err, "failed to update account settings")
accountIDs, err := manager.Store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get account ids")
require.Len(t, accountIDs, 1, "only one account should exist")
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to check account existence")
require.True(t, exists, "account should exist")
require.Len(t, account.Groups, 3, "groups should be added to the account")
accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
require.NoError(t, err, "failed to get account groups")
require.Len(t, accountGroups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{}
for _, g := range account.Groups {
for _, g := range accountGroups {
groupsByNames[g.Name] = g
}
@@ -744,60 +768,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
userPAT := &PersonalAccessToken{
ID: "tokenId",
UserID: "testuser",
HashedToken: encodedHashedToken,
CreatedAt: time.Now().UTC(),
}
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{
Store: store,
}
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
assert.Equal(t, "account_id", account.Id)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "testuser", user.Id)
assert.Equal(t, userPAT, pat)
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
require.NoError(t, err, "failed to create account")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
LastUsed: time.Time{},
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
userPAT := &PersonalAccessToken{
ID: "tokenId",
UserID: "someUser",
HashedToken: encodedHashedToken,
LastUsed: time.Time{},
}
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
require.NoError(t, err, "failed to save PAT")
am := DefaultAccountManager{
Store: store,
@@ -808,11 +825,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount(context.Background(), "account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
require.NoError(t, err, "failed to get PAT")
assert.True(t, !userPAT.LastUsed.IsZero())
}
func TestAccountManager_PrivateAccount(t *testing.T) {
@@ -823,15 +839,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
if err != nil {
t.Fatal(err)
}
if account == nil {
if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
account, err := manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@@ -850,32 +866,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil {
t.Fatal(err)
}
if account == nil {
t.Fatalf("expected to create an account for a user %s", userId)
}
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
require.NoError(t, err, "failed to get or create account by user")
require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
if account != nil && account.Domain != domain {
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
}
accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account domain and category")
require.Equal(t, domain, accDomain, "expected account domain to match")
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
require.NoError(t, err, "failed to get or create account by user")
if account == nil {
t.Fatalf("expected to get an account for a user %s", userId)
}
if account != nil && account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
}
accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account domain and category")
require.Equal(t, domain, accDomain, "expected account domain to match")
}
func TestAccountManager_GetAccountByUserID(t *testing.T) {
@@ -907,12 +913,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
if err != nil {
return nil, err
}
return account, nil
return am.Store.GetAccount(context.Background(), accountID)
}
func TestAccountManager_GetAccount(t *testing.T) {
@@ -1055,23 +1060,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
if err != nil {
t.Fatal(err)
}
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
require.NoError(t, err, "failed to get or create account by user")
serial := account.Network.CurrentSerial() // should be 0
network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account network")
if account.Network.Serial != 0 {
t.Errorf("expecting account network to have an initial Serial=0")
return
}
serial := network.CurrentSerial() // should be 0
require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
require.NoError(t, err, "failed to generate private key")
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
@@ -1079,16 +1079,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
if err != nil {
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
return
}
require.NoError(t, err, "failed to add peer")
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
}
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
if peer.Key != expectedPeerKey {
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
@@ -1131,10 +1125,12 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}
policy := Policy{
ID: "policy",
Enabled: true,
ID: xid.New().String(),
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -1212,10 +1208,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policyID := xid.New().String()
policy := Policy{
Enabled: true,
ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule",
PolicyID: policyID,
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -1249,19 +1250,25 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
policyID := xid.New().String()
policy := Policy{
Enabled: true,
ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule",
PolicyID: policyID,
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -1302,19 +1309,24 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}
err := manager.SaveGroup(context.Background(), account.Id, userID, &group)
require.NoError(t, err, "failed to save group")
policyID := xid.New().String()
policy := Policy{
Enabled: true,
ID: policyID,
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule",
PolicyID: policyID,
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -1324,6 +1336,9 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
},
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
@@ -1352,7 +1367,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
if err := manager.DeleteGroup(context.Background(), account.Id, userID, group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
@@ -1748,18 +1763,9 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
@@ -1774,11 +1780,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
@@ -1802,10 +1808,13 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
require.NoError(t, err, "failed to get account settings")
settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{}
@@ -1822,11 +1831,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1854,10 +1860,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -1871,10 +1874,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
settings, err := manager.GetAccountSettings(context.Background(), accountID, userID)
require.NoError(t, err, "failed to get account settings")
settings.PeerLoginExpirationEnabled = true
settings.PeerLoginExpiration = time.Hour
settings, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error")
failed := waitTimeout(wg, time.Second)
@@ -1884,10 +1889,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
settings.PeerLoginExpirationEnabled = false
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second)
if failed {
@@ -1902,30 +1905,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
settings.PeerLoginExpirationEnabled = false
settings.PeerLoginExpiration = time.Hour
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
settings.PeerLoginExpirationEnabled = false
settings.PeerLoginExpiration = time.Second
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
settings.PeerLoginExpirationEnabled = false
settings.PeerLoginExpiration = time.Hour * 24 * 181
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings)
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
@@ -2606,7 +2608,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2626,7 +2628,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2665,7 +2667,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")

View File

@@ -6,6 +6,7 @@ import (
"strconv"
"sync"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,105 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.NewAdminPermissionError()
}
oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, groups)
if err != nil {
return err
}
}
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
updateAccountPeers, err := am.areDNSSettingChangesAffectPeers(ctx, accountID, addedGroups, removedGroups)
if err != nil {
return fmt.Errorf("failed to check if dns settings changes affect peers: %w", err)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
return nil
})
if err != nil {
return err
}
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
group, ok := groupMap[id]
if ok {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
}
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
group, ok := groupMap[id]
if ok {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
}
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func (am *DefaultAccountManager) areDNSSettingChangesAffectPeers(ctx context.Context, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := am.anyGroupHasPeers(ctx, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, accountID, removedGroups)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{

View File

@@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestDNSAccount(t, am)
accountID, err := initTestDNSAccount(t, am)
if err != nil {
t.Fatal("failed to init testing account")
}
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("DNS settings for new accounts shouldn't return nil")
}
account.DNSSettings = DNSSettings{
err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{
DisabledManagementGroups: []string{group1ID},
}
})
require.NoError(t, err, "failed to update DNS settings")
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID)
if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
}
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
_, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
@@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestDNSAccount(t, am)
accountID, err := initTestDNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings)
if err != nil {
if testCase.shouldFail {
return
@@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err)
}
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
updatedAccount, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err)
}
@@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestDNSAccount(t, am)
accountID, err := initTestDNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
peer1, err := account.FindPeerByPubKey(dnsPeer1Key)
peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key)
if err != nil {
t.Error("failed to init testing account")
}
peer2, err := account.FindPeerByPubKey(dnsPeer2Key)
peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key)
if err != nil {
t.Error("failed to init testing account")
}
@@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
dnsSettings := account.DNSSettings.Copy()
accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account DNS settings")
dnsSettings := accountDNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings
err = am.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings)
require.NoError(t, err, "failed to update DNS settings")
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err)
@@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) {
return store, nil
}
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper()
peer1 := &nbpeer.Peer{
Key: dnsPeer1Key,
@@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account.Users[dnsRegularUserID] = &User{
Id: dnsRegularUserID,
Role: UserRoleUser,
err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain)
if err != nil {
return "", err
}
err := am.Store.SaveAccount(context.Background(), account)
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: dnsRegularUserID,
AccountID: dnsAccountID,
Role: UserRoleUser,
})
if err != nil {
return nil, err
return "", err
}
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil {
return nil, err
return "", err
}
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil {
return nil, err
return "", err
}
account, err = am.Store.GetAccount(context.Background(), account.Id)
peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key)
if err != nil {
return nil, err
return "", err
}
peer1, err = account.FindPeerByPubKey(peer1.Key)
_, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key)
if err != nil {
return nil, err
return "", err
}
_, err = account.FindPeerByPubKey(peer2.Key)
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{
{
ID: dnsGroup1ID,
AccountID: dnsAccountID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
},
{
ID: dnsGroup2ID,
AccountID: dnsAccountID,
Name: dnsGroup2ID,
},
})
if err != nil {
return nil, err
return "", err
}
newGroup1 := &group.Group{
ID: dnsGroup1ID,
Peers: []string{peer1.ID},
Name: dnsGroup1ID,
}
newGroup2 := &group.Group{
ID: dnsGroup2ID,
Name: dnsGroup2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
allGroup, err := account.GetGroupAll()
allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All")
if err != nil {
return nil, err
return "", err
}
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
ID: dnsNSGroup1,
Name: "ns-group-1",
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{
ID: dnsNSGroup1,
AccountID: dnsAccountID,
Name: "ns-group-1",
NameServers: []dns.NameServer{{
IP: netip.MustParseAddr(savedPeer1.IP.String()),
NSType: dns.UDPNameServerType,
@@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Primary: true,
Enabled: true,
Groups: []string{allGroup.ID},
}
err = am.Store.SaveAccount(context.Background(), account)
})
if err != nil {
return nil, err
return "", err
}
return am.Store.GetAccount(context.Background(), account.Id)
return dnsAccountID, nil
}
func generateTestData(size int) nbdns.Config {

View File

@@ -20,10 +20,10 @@ var (
)
type ephemeralPeer struct {
id string
account *Account
deadline time.Time
next *ephemeralPeer
id string
accountID string
deadline time.Time
next *ephemeralPeer
}
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
@@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
e.peersLock.Lock()
defer e.peersLock.Unlock()
@@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
return
}
e.addPeer(peer.ID, a, newDeadLine())
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
@@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background())
peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := newDeadLine()
count := 0
for _, a := range accounts {
for id, p := range a.Peers {
if p.Ephemeral {
count++
e.addPeer(id, a, t)
}
for _, p := range peers {
if p.Ephemeral {
count++
e.addPeer(p.AccountID, p.ID, t)
}
}
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
}
@@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
}
}
}
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{
id: id,
account: account,
deadline: deadline,
id: peerID,
accountID: accountID,
deadline: deadline,
}
if e.headPeer == nil {

View File

@@ -7,25 +7,12 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/require"
)
type MockStore struct {
Store
account *Account
}
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil
}
return nil, status.NewPeerNotFoundError(peerId)
accountID string
}
type MocAccountManager struct {
@@ -33,9 +20,8 @@ type MocAccountManager struct {
store *MockStore
}
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
delete(a.store.account.Peers, peerID)
return nil //nolint:nil
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error {
return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
}
func TestNewManager(t *testing.T) {
@@ -44,23 +30,26 @@ func TestNewManager(t *testing.T) {
return startTime
}
store := &MockStore{}
store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{
store: store,
}
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
}
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers")
}
func TestNewManagerPeerConnected(t *testing.T) {
@@ -69,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) {
return startTime
}
store := &MockStore{}
store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{
store: store,
}
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerConnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers")
}
func TestNewManagerPeerDisconnected(t *testing.T) {
@@ -97,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
return startTime
}
store := &MockStore{}
store := &MockStore{
Store: newStore(t),
}
am := MocAccountManager{
store: store,
}
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
require.NoError(t, err, "failed to seed peers")
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
for _, v := range peers {
mgr.OnPeerConnected(context.Background(), v)
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0")
require.NoError(t, err, "failed to get peer")
mgr.OnPeerDisconnected(context.Background(), peer)
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup(context.Background())
peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID)
require.NoError(t, err, "failed to get account peers")
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers")
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "")
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error {
accountID := "my account"
err := newAccountWithId(context.Background(), store, accountID, "", "")
if err != nil {
return err
}
store.accountID = accountID
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
AccountID: accountID,
Ephemeral: false,
}
store.account.Peers[p.ID] = p
err = store.AddPeerToAccount(context.Background(), p)
if err != nil {
return err
}
}
for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
AccountID: accountID,
Ephemeral: true,
}
store.account.Peers[p.ID] = p
err = store.AddPeerToAccount(context.Background(), p)
if err != nil {
return err
}
}
return nil
}

View File

@@ -37,8 +37,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err
}
if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return status.NewAdminPermissionError()
}
return nil
@@ -50,7 +54,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
return nil, err
}
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@@ -59,31 +63,34 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
return nil, err
}
return am.Store.GetAccountGroups(ctx, accountID)
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup})
}
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
var eventsToStore []func()
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var (
eventsToStore []func()
groupsToSave []*nbgroup.Group
)
for _, newGroup := range newGroups {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
@@ -91,7 +98,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
@@ -109,15 +116,15 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
oldGroup := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
newGroup.AccountID = accountID
groupsToSave = append(groupsToSave, newGroup)
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
events := am.prepareGroupEvents(ctx, userID, accountID, newGroup)
eventsToStore = append(eventsToStore, events...)
}
@@ -126,30 +133,45 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
newGroupIDs = append(newGroupIDs, newGroup.ID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs)
if err != nil {
return err
}
if areGroupChangesAffectPeers(account, newGroupIDs) {
am.updateAccountPeers(ctx, account)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil {
return fmt.Errorf("failed to save groups: %w", err)
}
return nil
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if oldGroup != nil {
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
@@ -159,12 +181,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
})
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
for _, peerID := range addedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
@@ -175,12 +198,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
})
}
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
for _, peerID := range removedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue
}
peerCopy := peer // copy to avoid closure issues
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
@@ -210,119 +234,108 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return nil
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
allGroup, err := account.GetGroupAll()
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
if allGroup.ID == groupID {
if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.validateDeleteGroup(ctx, group, userID); err != nil {
return err
}
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
return nil
}
// DeleteGroups deletes groups from an account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
account, err := am.Store.GetAccount(ctx, accountId)
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
var allErrors error
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
var (
allErrors error
groupIDsToDelete []string
deletedGroups []*nbgroup.Group
)
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
for _, groupID := range groupIDs {
group, ok := account.Groups[groupID]
if !ok {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
continue
}
if err := validateDeleteGroup(account, group, userId); err != nil {
if err := am.validateDeleteGroup(ctx, group, userID); err != nil {
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
continue
}
delete(account.Groups, groupID)
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
for _, g := range deletedGroups {
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
for _, group := range deletedGroups {
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
}
return allErrors
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
@@ -334,13 +347,27 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err != nil {
return err
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -348,41 +375,55 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
updated := false
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
updated = true
break
}
}
if areGroupChangesAffectPeers(account, []string{group.ID}) {
am.updateAccountPeers(ctx, account)
if !updated {
return nil
}
updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID})
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil {
return fmt.Errorf("failed to save group: %w", err)
}
return nil
})
if err != nil {
return err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userID]
if executingUser == nil {
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
@@ -390,32 +431,42 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
if account.Settings.Extra != nil {
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
if settings.Extra != nil {
if slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
}
@@ -424,17 +475,30 @@ func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string)
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) {
routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -446,7 +510,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
@@ -454,11 +524,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
@@ -468,7 +545,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) {
users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
@@ -478,30 +561,46 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
func (am *DefaultAccountManager) anyGroupHasPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return false, err
}
}
return false
}
func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
return true
}
if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
return true
}
if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
return true
if group.HasPeers() {
return true, nil
}
}
return false
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked {
return true, nil
}
if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}

View File

@@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
}
routeResource := &route.Route{
ID: "example route",
Groups: []string{groupForRoute.ID},
ID: "example route",
AccountID: accountID,
Groups: []string{groupForRoute.ID},
}
routePeerGroupResource := &route.Route{
ID: "example route with peer groups",
AccountID: accountID,
PeerGroups: []string{groupForRoute2.ID},
}
nameServerGroup := &nbdns.NameServerGroup{
ID: "example name server group",
Groups: []string{groupForNameServerGroups.ID},
ID: "example name server group",
AccountID: accountID,
Groups: []string{groupForNameServerGroups.ID},
}
policy := &Policy{
ID: "example policy",
ID: "example policy",
AccountID: accountID,
Rules: []*PolicyRule{
{
ID: "example policy rule",
PolicyID: "example policy",
Destinations: []string{groupForPolicies.ID},
},
},
@@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
setupKey := &SetupKey{
Id: "example setup key",
AccountID: accountID,
AutoGroups: []string{groupForSetupKeys.ID},
}
user := &User{
Id: "example user",
AccountID: accountID,
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
account.Policies = append(account.Policies, policy)
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
if err != nil {
return nil, nil, err
}
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource)
if err != nil {
return nil, nil, err
}
acc, err := am.Store.GetAccount(context.Background(), account.Id)
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup)
if err != nil {
return nil, nil, err
}
err = am.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
if err != nil {
return nil, nil, err
}
err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user)
if err != nil {
return nil, nil, err
}
err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{
groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies,
groupForSetupKeys, groupForUsers, groupForUsers, groupForIntegration,
})
if err != nil {
return nil, nil, err
}
acc, err := am.Store.GetAccount(context.Background(), accountID)
if err != nil {
return nil, nil, err
}
@@ -394,24 +424,28 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
ID: "groupB",
AccountID: account.Id,
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
ID: "groupC",
AccountID: account.Id,
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
},
{
ID: "groupD",
Name: "GroupD",
Peers: []string{},
ID: "groupD",
AccountID: account.Id,
Name: "GroupD",
Peers: []string{},
},
})
assert.NoError(t, err)
@@ -430,9 +464,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
ID: "groupB",
AccountID: account.Id,
Name: "GroupB",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
@@ -501,10 +536,13 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// adding a group to policy
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true,
ID: "policy",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule",
PolicyID: "policy",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -524,9 +562,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
@@ -593,9 +632,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
ID: "groupC",
AccountID: account.Id,
Name: "GroupC",
Peers: []string{peer1.ID, peer3.ID},
})
assert.NoError(t, err)
@@ -610,6 +650,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to route", func(t *testing.T) {
newRoute := route.Route{
ID: "route",
AccountID: account.Id,
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: "superNet",
NetworkType: route.IPv4Network,
@@ -634,9 +675,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
@@ -661,9 +703,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID},
ID: "groupD",
AccountID: account.Id,
Name: "GroupD",
Peers: []string{peer1.ID},
})
assert.NoError(t, err)

View File

@@ -100,13 +100,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
updatedAccountSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
resp := toAccountResponse(accountID, updatedAccountSettings)
util.WriteJSONObject(r.Context(), w, &resp)
}

View File

@@ -29,7 +29,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -39,9 +39,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
accCopy := account.Copy()
accCopy.UpdateSettings(newSettings)
return accCopy, nil
return newSettings.Copy(), nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -68,7 +68,7 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return nil, fmt.Errorf("unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
GetUserPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {

View File

@@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
)
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
accountManager.GetAccountInfoFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,

View File

@@ -9,9 +9,9 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
@@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@@ -47,7 +47,7 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(r.Context(), token)
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.

View File

@@ -33,7 +33,8 @@ var testAccount = &server.Account{
Domain: domain,
Users: map[string]*server.User{
userID: {
Id: userID,
Id: userID,
AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{
tokenID: {
ID: tokenID,
@@ -49,11 +50,11 @@ var testAccount = &server.Account{
},
}
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
@@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
)
authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockGetAccountInfoFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockCheckUserAccessByJWTGroups,

View File

@@ -120,6 +120,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
updatedNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
AccountID: accountID,
Name: req.Name,
Description: req.Description,
Primary: req.Primary,

View File

@@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil
}
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
util.WriteError(ctx, err, w)
return
@@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
}
dnsDomain := h.accountManager.GetDNSDomain()
groupsInfo := toGroupsInfo(account.Groups, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, err, w)
return
}
groupsInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
@@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
}
}
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
if err != nil {
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
validPeers, err := h.accountManager.GetValidatedPeers(account)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
@@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w)
return
case http.MethodGet, http.MethodPut:
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
case http.MethodGet:
h.getPeer(r.Context(), accountID, peerID, userID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range account.Peers {
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(peerGroups)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
}
}
dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account)
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
@@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum {
groupsInfo := make([]api.GroupMinimum, 0, len(groups))
for _, group := range groups {
_, ok := groupsChecked[group.ID]
if ok {
continue
}
groupsChecked[group.ID] = struct{}{}
for _, pk := range group.Peers {
if pk == peerID {
info := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
groupsInfo = append(groupsInfo, info)
break
}
}
groupsInfo = append(groupsInfo, api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
})
}
return groupsInfo
}

View File

@@ -39,6 +39,68 @@ const (
)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: "test_id",
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: "test_id",
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: "test_id",
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return &PeersHandler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -64,77 +126,37 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}
return p, nil
},
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
GetUserPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
ListPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
peersID := make([]string, len(peers))
for _, peer := range peers {
peersID = append(peersID, peer.ID)
}
return []*nbgroup.Group{
{
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: peersID,
},
}, nil
},
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) {
return account, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: accountID,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),
Mask: net.IPv4Mask(255, 255, 0, 0),
},
Serial: 51,
},
}
return account, nil
},
HasConnectedChannelFunc: func(peerID string) bool {

View File

@@ -130,6 +130,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy := server.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,

View File

@@ -163,13 +163,16 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
}
}
isUpdate := postureChecksID != ""
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
postureChecks.AccountID = accountID
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks, isUpdate); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -40,7 +40,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return p, nil
},
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks, _ bool) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks

View File

@@ -4,6 +4,8 @@ import (
"context"
"errors"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
@@ -56,13 +58,15 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
if len(groups) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(ctx, accountId)
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountId)
if err != nil {
return false, err
}
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
for _, accountGroup := range accountGroups {
if accountGroup.ID == group {
found = true
break
@@ -76,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId
return true, nil
}
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
}

View File

@@ -461,7 +461,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 10 * time.Second,
Timeout: 2 * time.Second,
Timeout: 200 * time.Second,
}))
if err != nil {
return nil, nil, err

View File

@@ -22,16 +22,17 @@ import (
)
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@@ -45,16 +46,16 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -89,15 +90,15 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
@@ -123,7 +124,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncAndMarkPeer is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
@@ -131,7 +132,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
@@ -171,16 +177,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser(
// GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountIDByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountIDByUser(
ctx context.Context, userId, domain string,
) (*server.Account, error) {
if am.GetOrCreateAccountByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
) (string, error) {
if am.GetOrCreateAccountIDByUserFunc != nil {
return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain)
}
return nil, status.Errorf(
return "", status.Errorf(
codes.Unimplemented,
"method GetOrCreateAccountByUser is not implemented",
"method GetOrCreateAccountIDByUser is not implemented",
)
}
@@ -222,19 +228,19 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat)
// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
if am.GetAccountInfoFromPATFunc != nil {
return am.GetAccountInfoFromPATFunc(ctx, token)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
@@ -354,14 +360,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil {
@@ -626,12 +624,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetPeersFunc != nil {
return am.GetPeersFunc(ctx, accountID, userID)
// GetUserPeers mocks GetUserPeers of the AccountManager interface
func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetUserPeersFunc != nil {
return am.GetUserPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented")
}
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface
@@ -675,7 +673,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
}
@@ -691,9 +689,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, account)
return am.SyncPeerFunc(ctx, sync, accountID)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
@@ -739,9 +737,9 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks, isUpdate)
}
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}
@@ -840,3 +838,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
}
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
if am.GetPeerGroupsFunc != nil {
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}
// ListPeers mocks ListPeers of the AccountManager interface
func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.ListPeersFunc != nil {
return am.ListPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented")
}

View File

@@ -3,7 +3,9 @@ package server
import (
"context"
"errors"
"fmt"
"regexp"
"slices"
"unicode/utf8"
"github.com/miekg/dns"
@@ -24,26 +26,31 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
if user.IsRegularUser() {
return nil, status.NewUnauthorizedToViewNSGroupsError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Description: description,
NameServers: nameServerList,
@@ -54,92 +61,136 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
err = validateNameServerGroup(false, newNSGroup, account)
err = am.validateNameServerGroup(ctx, accountID, newNSGroup)
if err != nil {
return nil, err
}
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, newNSGroup.Groups)
if err != nil {
return nil, err
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, account)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup); err != nil {
return fmt.Errorf("failed to create nameserver group: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return newNSGroup.Copy(), nil
}
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
err = validateNameServerGroup(true, nsGroupToSave, account)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
oldNSGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
nsGroupToSave.AccountID = accountID
if err = am.validateNameServerGroup(ctx, accountID, nsGroupToSave); err != nil {
return err
}
updateAccountPeers, err := am.areNameServerGroupChangesAffectPeers(ctx, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave); err != nil {
return fmt.Errorf("failed to update nameserver group: %w", err)
}
return nil
})
if err != nil {
return err
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, account)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
nsGroup := account.NameServerGroups[nsGroupID]
if nsGroup == nil {
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
if err != nil {
return err
}
if anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, account)
updateAccountPeers, err := am.anyGroupHasPeers(ctx, accountID, nsGroup.Groups)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
return fmt.Errorf("failed to delete nameserver group: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
@@ -150,39 +201,44 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewUnauthorizedToViewNSGroupsError()
}
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
func (am *DefaultAccountManager) validateNameServerGroup(ctx context.Context, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, account.Groups)
nsServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
if err != nil {
return err
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, groups)
if err != nil {
return err
}
@@ -190,6 +246,24 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
return nil
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(ctx context.Context, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := am.anyGroupHasPeers(ctx, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
@@ -213,14 +287,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
for _, nsGroup := range groups {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
}
}
@@ -228,14 +302,14 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
}
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 3 {
nsListLength := len(list)
if nsListLength == 0 || nsListLength > 3 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
}
return nil
}
func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
func validateGroups(list []string, groups []*nbgroup.Group) error {
if len(list) == 0 {
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
}
@@ -244,13 +318,8 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
found := slices.ContainsFunc(groups, func(group *nbgroup.Group) bool { return group.ID == id })
if !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
}
@@ -277,11 +346,3 @@ func validateDomain(domain string) error {
return nil
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
}

View File

@@ -6,6 +6,7 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
accountID, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
outNSGroup, err := am.CreateNameServerGroup(
context.Background(),
account.Id,
accountID,
testCase.inputArgs.name,
testCase.inputArgs.description,
testCase.inputArgs.nameServers,
@@ -408,7 +409,7 @@ func TestCreateNameServerGroup(t *testing.T) {
// assign generated ID
testCase.expectedNSGroup.ID = outNSGroup.ID
testCase.expectedNSGroup.AccountID = accountID
if !testCase.expectedNSGroup.IsEqual(outNSGroup) {
t.Errorf("new nameserver group didn't match expected ns group:\nGot %#v\nExpected:%#v\n", outNSGroup, testCase.expectedNSGroup)
}
@@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
accountID, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("account should be saved")
}
testCase.existingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup)
require.NoError(t, err, "failed to save existing nameserver group")
var nsGroupToSave *nbdns.NameServerGroup
if !testCase.skipCopying {
nsGroupToSave = testCase.existingNSGroup.Copy()
@@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) {
}
}
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave)
err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
account, err = am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved)
savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID)
require.NoError(t, err, "failed to get saved nameserver group")
testCase.expectedNSGroup.AccountID = accountID
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
}
@@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
accountID, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
testingNSGroup.AccountID = accountID
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup)
require.NoError(t, err, "failed to save nameserver group")
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID)
if err != nil {
t.Error("deleting nameserver group failed with error: ", err)
}
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Error("failed to retrieve saved account with error: ", err)
}
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
if found {
t.Error("nameserver group shouldn't be found after delete")
}
_, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID)
require.NotNil(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok, "error should be a status error")
assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete")
}
func TestGetNameServerGroup(t *testing.T) {
@@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
accountID, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID)
if err != nil {
t.Error("getting existing nameserver group failed with error: ", err)
}
@@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("got a nil group while getting nameserver group with ID")
}
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing")
_, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing")
if err == nil {
t.Error("getting not existing nameserver group should return error, got nil")
}
@@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) {
return store, nil
}
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper()
accountID := "testingAcc"
userID := testUserID
domain := "example.com"
peer1 := &nbpeer.Peer{
Key: nsGroupPeer1Key,
Name: "test-host1@netbird.io",
@@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
}
existingNSGroup := nbdns.NameServerGroup{
ID: existingNSGroupID,
AccountID: accountID,
Name: existingNSGroupName,
Description: "",
NameServers: []nbdns.NameServer{
@@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
Enabled: true,
}
accountID := "testingAcc"
userID := testUserID
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
newGroup1 := &nbgroup.Group{
ID: group1ID,
Name: group1ID,
}
newGroup2 := &nbgroup.Group{
ID: group2ID,
Name: group2ID,
}
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
err := am.Store.SaveAccount(context.Background(), account)
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
if err != nil {
return nil, err
return "", err
}
err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup)
if err != nil {
return "", err
}
err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{
{
ID: group1ID,
AccountID: accountID,
Name: group1ID,
},
{
ID: group2ID,
AccountID: accountID,
Name: group2ID,
},
})
if err != nil {
return "", err
}
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
if err != nil {
return nil, err
return "", err
}
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
if err != nil {
return nil, err
return "", err
}
return account, nil
return accountID, nil
}
func TestValidateDomain(t *testing.T) {

View File

@@ -11,6 +11,7 @@ import (
"sync"
"time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -50,34 +51,54 @@ type PeerLogin struct {
ConnectionIP net.IP
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// ListPeers returns a list of peers under the given account.
func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
}
// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(ctx, accountID)
func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, err
}
peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer)
regularUser := !user.HasAdminPower() && !user.IsServiceUser
if regularUser && account.Settings.RegularUsersViewBlocked {
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return peers, nil
}
for _, peer := range account.Peers {
if regularUser && user.Id != peer.UserID {
accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
for _, peer := range accountPeers {
if user.IsRegularUser() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin
continue
}
@@ -86,10 +107,15 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
peersMap[peer.ID] = p
}
if !regularUser {
if user.IsAdminOrServiceUser() {
return peers, nil
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
@@ -107,37 +133,42 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
if err != nil {
return err
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
if err != nil {
return err
}
if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
}
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
@@ -157,16 +188,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
err = am.Store.SavePeerLocation(account.Id, peer)
err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
}
account.UpdatePeer(peer)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
if err != nil {
return false, err
}
@@ -176,39 +205,50 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
peer := account.GetPeer(update.ID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
if err != nil {
return nil, err
}
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
if err != nil {
return nil, err
}
var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool
if peer.SSHEnabled != update.SSHEnabled {
peer.SSHEnabled = update.SSHEnabled
event := activity.PeerSSHEnabled
if !update.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
sshChanged = true
}
peerLabelUpdated := peer.Name != update.Name
if peerLabelUpdated {
if peer.Name != update.Name {
peer.Name = update.Name
peerLabelChanged = true
existingLabels := account.getPeerDNSLabels()
existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
if err != nil {
return nil, err
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
@@ -216,134 +256,107 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
peer.DNSLabel = newLabel
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
}
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.LoginExpirationEnabled = update.LoginExpirationEnabled
loginExpirationChanged = true
}
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
inactivityExpirationChanged = true
}
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return nil, err
}
if sshChanged {
event := activity.PeerSSHEnabled
if !peer.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
if peerLabelChanged {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
am.updateAccountPeers(ctx, accountID)
}
if loginExpirationChanged {
event := activity.PeerLoginExpirationEnabled
if !update.LoginExpirationEnabled {
if !peer.LoginExpirationEnabled {
event = activity.PeerLoginExpirationDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
}
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
if inactivityExpirationChanged {
event := activity.PeerInactivityExpirationEnabled
if !update.InactivityExpirationEnabled {
if !peer.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
account.UpdatePeer(peer)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
if peerLabelUpdated {
am.updateAccountPeers(ctx, account)
}
return peer, nil
}
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
// we might have some inconsistencies
peers := make([]*nbpeer.Peer, 0, len(peerIDs))
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peers = append(peers, peer)
}
// the 2nd loop performs the actual modification
for _, peer := range peers {
err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID)
if err != nil {
return err
}
account.DeletePeer(peer.ID)
am.peersUpdateManager.SendUpdate(ctx, peer.ID,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill those field for backward compatibility
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
// new field
NetworkMap: &proto.NetworkMap{
Serial: account.Network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
}
return nil
}
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
if err != nil {
return err
}
updateAccountPeers := isPeerInActiveGroup(account, peerID)
if peerAccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
err = am.deletePeers(ctx, account, []string{peerID}, userID)
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
if err != nil {
return err
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
var peer *nbpeer.Peer
var addPeerRemovedEvents []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to get peer to delete: %w", err)
}
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
return nil
})
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent()
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -405,7 +418,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false
if len(userID) > 0 {
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(userID)
accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
}
@@ -436,12 +449,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var setupKeyID string
var setupKeyName string
var ephemeral bool
var groupsToAdd []string
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil {
@@ -550,7 +563,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
@@ -584,30 +597,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
unlock()
unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
allGroup, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
@@ -630,14 +629,14 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
if peer.UserID != "" {
user, err := account.FindUser(peer.UserID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return nil, nil, nil, err
}
@@ -648,48 +647,38 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
}
if peerLoginExpired(ctx, peer, account.Settings) {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
if peerLoginExpired(ctx, peer, settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
if err != nil {
return nil, nil, nil, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, nil, nil, err
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account)
}
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
am.updateAccountPeers(ctx, accountID)
}
var postureChecks []*posture.Checks
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, postureChecks, nil
}
if isStatusChanged {
am.updateAccountPeers(ctx, account)
}
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
}
// LoginPeer logs in or registers a peer.
@@ -764,7 +753,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -795,7 +784,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
if shouldStorePeer {
err = am.Store.SavePeer(ctx, accountID, peer)
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, nil, nil, err
}
@@ -804,16 +793,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
unlockPeer()
unlockPeer = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
@@ -845,21 +829,33 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
Network: network.Copy(),
}
return peer, emptyMap, nil, nil
}
approvedPeersMap, err := am.GetValidatedPeers(account)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil {
return nil, nil, nil, err
}
postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -873,7 +869,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
if err != nil {
return err
}
@@ -920,45 +916,51 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID)
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
// if admin or user owns this peer, return peer
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID {
if user.IsAdminOrServiceUser() || peer.UserID == userID {
return peer, nil
}
// it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well.
userPeers, err := account.FindUserPeers(userID)
userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil {
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers {
@@ -973,7 +975,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
start := time.Now()
defer func() {
if am.metrics != nil {
@@ -981,9 +983,15 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
}
}()
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return
@@ -1007,7 +1015,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p)
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
@@ -1017,6 +1030,236 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return 0, false
}
if len(peersWithExpiry) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return 0, false
}
if len(peersWithInactivity) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithInactivity {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers, nil
}
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, inactivePeer := range peersWithInactivity {
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers, nil
}
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroups := make([]*nbgroup.Group, 0)
for _, group := range groups {
if slices.Contains(group.Peers, peerID) {
peerGroups = append(peerGroups, group)
}
}
return peerGroups, nil
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
if err != nil {
return nil, err
}
groupIDs := make([]string, 0, len(groups))
for _, group := range groups {
groupIDs = append(groupIDs, group.ID)
}
return groupIDs, err
}
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
existingLabels := make(lookupMap)
for _, label := range dnsLabels {
existingLabels[label] = struct{}{}
}
return existingLabels, nil
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
return false, err
}
return am.areGroupChangesAffectPeers(ctx, accountID, peerGroupIDs)
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, store Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func()
for _, peer := range peers {
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil {
return nil, fmt.Errorf("failed to validate peer: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get account network: %w", err)
}
if err = store.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil {
return nil, fmt.Errorf("failed to delete peer: %w", err)
}
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
})
}
return peerDeletedEvents, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
@@ -1024,15 +1267,3 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
}
return labelMap
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(account *Account, peerID string) bool {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(account, peerGroupIDs)
}

View File

@@ -44,7 +44,7 @@ type Peer struct {
// CreatedAt records the time the peer was created
CreatedAt time.Time
// Indicate ephemeral peer attribute
Ephemeral bool
Ephemeral bool `gorm:"index"`
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"`
}

View File

@@ -467,21 +467,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{
Id: someUser,
Role: UserRoleUser,
}
account.Settings.RegularUsersViewBlocked = false
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
require.NoError(t, err, "failed to create account")
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
return
}
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: someUser,
AccountID: accountID,
Role: UserRoleUser,
})
require.NoError(t, err, "failed to create user")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
// two peers one added by a regular user and one with a setup key
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
if err != nil {
t.Fatal("error creating setup key")
return
@@ -535,7 +539,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer)
// delete the all-to-all policy so that user's peer1 has no access to peer2
for _, policy := range account.Policies {
accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account policies")
for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
if err != nil {
t.Fatal(err)
@@ -563,7 +570,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer)
}
func TestDefaultAccountManager_GetPeers(t *testing.T) {
func TestDefaultAccountManager_GetUserPeers(t *testing.T) {
testCases := []struct {
name string
role UserRole
@@ -654,21 +661,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
require.NoError(t, err, "failed to create account")
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: someUser,
AccountID: accountID,
Role: testCase.role,
IsServiceUser: testCase.isServiceUser,
}
account.Policies = []*Policy{}
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
})
require.NoError(t, err, "failed to create user")
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
return
accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account policies")
for _, policy := range accountPolicies {
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
require.NoError(t, err, "failed to delete policy")
}
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings)
require.NoError(t, err, "failed to save account settings")
peerKey1, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
@@ -699,7 +718,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
peers, err := manager.GetPeers(context.Background(), accountID, someUser)
peers, err := manager.GetUserPeers(context.Background(), accountID, someUser)
if err != nil {
t.Fatal(err)
return
@@ -724,10 +743,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
adminUser := "account_creator"
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[regularUser] = &User{
Id: regularUser,
Role: UserRoleUser,
err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "")
if err != nil {
return nil, "", "", err
}
err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: regularUser,
AccountID: accountID,
Role: UserRoleUser,
})
if err != nil {
return nil, "", "", err
}
// Create peers
@@ -741,31 +768,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
Status: &nbpeer.PeerStatus{},
UserID: regularUser,
}
account.Peers[peer.ID] = peer
err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, "", "", err
}
}
// Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{
ID: groupID,
Name: fmt.Sprintf("Group %d", i),
ID: groupID,
AccountID: accountID,
Name: fmt.Sprintf("Group %d", i),
}
for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
account.Groups[groupID] = group
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
return nil, "", "", err
}
// Create a policy for this group
policy := &Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
ID: fmt.Sprintf("policy-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
PolicyID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true,
Sources: []string{groupID},
@@ -776,22 +812,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
},
},
}
account.Policies = append(account.Policies, policy)
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
if err != nil {
return nil, "", "", err
}
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{
ID: "PostureChecksAll",
AccountID: accountID,
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
}
err = manager.Store.SaveAccount(context.Background(), account)
})
if err != nil {
return nil, "", "", err
}
@@ -824,9 +861,9 @@ func BenchmarkGetPeers(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID)
_, err := manager.GetUserPeers(context.Background(), accountID, userID)
if err != nil {
b.Fatalf("GetPeers failed: %v", err)
b.Fatalf("GetUserPeers failed: %v", err)
}
}
})
@@ -876,7 +913,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
manager.updateAccountPeers(ctx, accountID)
}
duration := time.Since(start)
@@ -1401,10 +1438,13 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
ID: "policy",
Enabled: true,
ID: "policy",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule",
PolicyID: "policy",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},

View File

@@ -8,9 +8,8 @@ import (
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/rs/xid"
)
const (
@@ -58,7 +57,7 @@ type PersonalAccessTokenGenerated struct {
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
func CreateNewPAT(name string, expirationInDays int, targetUserID, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken()
if err != nil {
return nil, err
@@ -67,6 +66,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(),
UserID: targetUserID,
Name: name,
HashedToken: hashedToken,
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),

View File

@@ -3,13 +3,13 @@ package server
import (
"context"
_ "embed"
"slices"
"fmt"
"strconv"
"strings"
"github.com/netbirdio/netbird/management/proto"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -171,6 +171,7 @@ type Policy struct {
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
AccountID: p.AccountID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
@@ -211,7 +212,6 @@ func (p *Policy) ruleGroups() []string {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}
@@ -343,30 +343,73 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
for index, rule := range policy.Rules {
rule.Sources = getValidGroupIDs(groups, rule.Sources)
rule.Destinations = getValidGroupIDs(groups, rule.Destinations)
policy.Rules[index] = rule
}
if policy.SourcePostureChecks != nil {
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
}
updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, isUpdate)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
saveFunc := transaction.SavePolicy
if !isUpdate {
saveFunc = transaction.CreatePolicy
}
if err := saveFunc(ctx, LockingStrengthUpdate, policy); err != nil {
return fmt.Errorf("failed to save policy: %w", err)
}
return nil
})
if err != nil {
return err
}
@@ -377,7 +420,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -385,115 +428,91 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
policy, err := am.deletePolicy(account, policyID)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
policy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
if err != nil {
return err
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
updateAccountPeers, err := am.arePolicyChangesAffectPeers(ctx, policy, false)
if err != nil {
return err
}
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account)
if err = transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID); err != nil {
return fmt.Errorf("failed to delete policy: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// ListPolicies from the store
// ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
}
policy := account.Policies[policyIdx]
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
return policy, nil
}
// savePolicy saves or updates a policy in the given account.
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
for index, rule := range policyToSave.Rules {
rule.Sources = filterValidGroupIDs(account, rule.Sources)
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func (am *DefaultAccountManager) arePolicyChangesAffectPeers(ctx context.Context, policy *Policy, isUpdate bool) (bool, error) {
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
existingPolicy, err := am.Store.GetPolicyByID(ctx, LockingStrengthShare, policy.AccountID, policy.ID)
if err != nil {
return false, err
}
oldPolicy := account.Policies[policyIdx]
// Update the existing policy
account.Policies[policyIdx] = policyToSave
if !policyToSave.Enabled && !oldPolicy.Enabled {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil
}
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
hasPeers, err := am.anyGroupHasPeers(ctx, policy.AccountID, existingPolicy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, policy.AccountID, policy.ruleGroups())
}
return result
return am.anyGroupHasPeers(ctx, policy.AccountID, policy.ruleGroups())
}
// getAllPeersFromGroups for given peer ID and list of groups
@@ -574,27 +593,52 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil
}
// filterValidPostureChecks filters and returns the posture check IDs from the given list
// that are valid within the provided account.
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string {
validPostureCheckIDs := make(map[string]struct{})
for _, check := range postureChecks {
validPostureCheckIDs[check.ID] = struct{}{}
}
validIDs := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
if _, exists := validPostureCheckIDs[id]; exists {
validIDs = append(validIDs, id)
}
}
return result
return validIDs
}
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string {
validGroupIDs := make(map[string]struct{})
for _, group := range groups {
validGroupIDs[group.ID] = struct{}{}
}
validIDs := make([]string, 0, len(groupIDs))
for _, id := range groupIDs {
if _, exists := validGroupIDs[id]; exists {
validIDs = append(validIDs, id)
}
}
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.FirewallRule{
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
}
return result

View File

@@ -832,24 +832,28 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
ID: "groupB",
AccountID: account.Id,
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{},
ID: "groupC",
AccountID: account.Id,
Name: "GroupC",
Peers: []string{},
},
{
ID: "groupD",
Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
ID: "groupD",
AccountID: account.Id,
Name: "GroupD",
Peers: []string{peer1.ID, peer2.ID},
},
})
assert.NoError(t, err)
@@ -862,11 +866,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
policy := Policy{
ID: "policy-rule-groups-no-peers",
Enabled: true,
ID: "policy-rule-groups-no-peers",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-rule-groups-no-peers",
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
@@ -896,11 +902,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
policy := Policy{
ID: "policy-source-has-peers-destination-none",
Enabled: true,
ID: "policy-source-has-peers-destination-none",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-source-has-peers-destination-none",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
@@ -931,11 +939,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
policy := Policy{
ID: "policy-destination-has-peers-source-none",
Enabled: true,
ID: "policy-destination-has-peers-source-none",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-destination-has-peers-source-none",
Enabled: false,
Sources: []string{"groupC"},
Destinations: []string{"groupD"},
@@ -966,11 +976,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
ID: "policy-source-destination-peers",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
@@ -1000,11 +1012,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: false,
ID: "policy-source-destination-peers",
AccountID: account.Id,
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
@@ -1035,11 +1049,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
AccountID: account.Id,
Description: "updated description",
Enabled: false,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -1069,11 +1085,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
ID: "policy-source-destination-peers",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
PolicyID: "policy-source-destination-peers",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},

View File

@@ -2,16 +2,14 @@ package server
import (
"context"
"fmt"
"slices"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
const (
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -20,85 +18,127 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
return nil, status.NewAdminPermissionError()
}
if err := postureChecks.Validate(); err != nil {
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.NewAdminPermissionError()
}
if err = am.validatePostureChecks(ctx, accountID, postureChecks); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
exists, uniqName := am.savePostureChecks(account, postureChecks)
// we do not allow create new posture checks with non uniq name
if !exists && !uniqName {
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
updateAccountPeers, err := am.arePostureCheckChangesAffectPeers(ctx, accountID, postureChecks.ID, isUpdate)
if err != nil {
return err
}
action := activity.PostureCheckCreated
if exists {
action = activity.PostureCheckUpdated
account.Network.IncSerial()
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if isUpdate {
action = activity.PostureCheckUpdated
if err = am.Store.SaveAccount(ctx, account); err != nil {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return fmt.Errorf("failed to get posture checks: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
}
if err = transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks); err != nil {
return fmt.Errorf("failed to save posture checks: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
am.updateAccountPeers(ctx, account)
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
func (am *DefaultAccountManager) validatePostureChecks(ctx context.Context, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
account, err := am.Store.GetAccount(ctx, accountID)
checks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
for _, check := range checks {
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
}
}
return nil
}
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
return status.NewAdminPermissionError()
}
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.isPostureCheckLinkedToPolicy(ctx, postureChecksID, accountID); err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
return fmt.Errorf("failed to delete posture checks: %w", err)
}
return nil
})
if err != nil {
return err
}
@@ -107,132 +147,123 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return nil
}
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
uniqName = true
for i, p := range account.PostureChecks {
if !exists && p.ID == postureChecks.ID {
account.PostureChecks[i] = postureChecks
exists = true
}
if p.Name == postureChecks.Name {
uniqName = false
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func (am *DefaultAccountManager) isPostureCheckLinkedToPolicy(ctx context.Context, postureChecksID, accountID string) error {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
}
}
if !exists {
account.PostureChecks = append(account.PostureChecks, postureChecks)
}
return
}
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
postureChecksIdx := -1
for i, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
}
// Check if posture check is linked to any policy
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
}
postureChecks := account.PostureChecks[postureChecksIdx]
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
return postureChecks, nil
return nil
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
peerPostureChecks := make(map[string]posture.Checks)
if len(account.PostureChecks) == 0 {
return nil
func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
postureChecks, err := am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
if err != nil || len(postureChecks) == 0 {
return nil, err
}
for _, policy := range account.Policies {
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerPostureChecks := make(map[string]*posture.Checks)
for _, policy := range policies {
if !policy.Enabled {
continue
}
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
addPolicyPostureChecks(account, policy, peerPostureChecks)
isInGroup, err := am.isPeerInPolicySourceGroups(ctx, accountID, peerID, policy)
if err != nil {
return nil, err
}
if isInGroup {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
if err == nil {
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
}
}
}
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
for _, check := range peerPostureChecks {
checkCopy := check
postureChecksList = append(postureChecksList, &checkCopy)
}
return postureChecksList
return maps.Values(peerPostureChecks), nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
func (am *DefaultAccountManager) isPeerInPolicySourceGroups(ctx context.Context, accountID, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group, ok := account.Groups[sourceGroup]
if ok && slices.Contains(group.Peers, peerID) {
return true
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
if err != nil {
log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false
}
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
for _, postureCheck := range account.PostureChecks {
if postureCheck.ID == sourcePostureCheckID {
peerPostureChecks[sourcePostureCheckID] = *postureCheck
}
}
}
}
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
for _, policy := range account.Policies {
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
return true, policy
}
}
return false, nil
}
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func (am *DefaultAccountManager) arePostureCheckChangesAffectPeers(ctx context.Context, accountID, postureCheckID string, exists bool) (bool, error) {
if !exists {
return false
return false, nil
}
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
if !isLinked {
return false
policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
if err != nil {
return false, err
}
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := am.anyGroupHasPeers(ctx, accountID, policy.ruleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
}

View File

@@ -5,8 +5,9 @@ import (
"testing"
"time"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group"
@@ -26,41 +27,43 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestPostureChecksAccount(am)
accountID, err := initTestPostureChecksAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
err := am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}, false)
assert.Error(t, err)
// regular users cannot list check
_, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID)
_, err = am.ListPostureChecks(context.Background(), accountID, regularUserID)
assert.Error(t, err)
// should be possible to create posture check with uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
Name: postureCheckName,
err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID,
AccountID: accountID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.26.0",
},
},
})
}, false)
assert.NoError(t, err)
// admin users can list check
checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID)
checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err)
assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: "new-id",
Name: postureCheckName,
err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: "new-id",
AccountID: accountID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
@@ -70,57 +73,61 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
},
},
},
})
}, false)
assert.Error(t, err)
// admins can update posture checks
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
ID: postureCheckID,
Name: postureCheckName,
err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{
ID: postureCheckID,
AccountID: accountID,
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.27.0",
},
},
})
}, false)
assert.NoError(t, err)
// users should not be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, regularUserID)
assert.Error(t, err)
// admin should be able to delete posture checks
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
err = am.DeletePostureChecks(context.Background(), accountID, postureCheckID, adminUserID)
assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID)
assert.NoError(t, err)
assert.Len(t, checks, 0)
})
}
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) {
accountID := "testingAccount"
domain := "example.com"
admin := &User{
Id: adminUserID,
Role: UserRoleAdmin,
}
user := &User{
Id: regularUserID,
Role: UserRoleUser,
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
err := am.Store.SaveAccount(context.Background(), account)
err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain)
if err != nil {
return nil, err
return "", err
}
return am.Store.GetAccount(context.Background(), account.Id)
err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
{
Id: adminUserID,
AccountID: accountID,
Role: UserRoleAdmin,
},
{
Id: regularUserID,
AccountID: accountID,
Role: UserRoleUser,
},
})
if err != nil {
return "", err
}
return accountID, nil
}
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
@@ -128,19 +135,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{
{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
},
{
ID: "groupB",
Name: "GroupB",
Peers: []string{},
ID: "groupB",
AccountID: account.Id,
Name: "GroupB",
Peers: []string{},
},
{
ID: "groupC",
Name: "GroupC",
Peers: []string{},
ID: "groupC",
AccountID: account.Id,
Name: "GroupC",
Peers: []string{},
},
})
assert.NoError(t, err)
@@ -169,7 +179,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
assert.NoError(t, err)
select {
@@ -192,7 +202,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err)
select {
@@ -203,11 +213,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
})
policy := Policy{
ID: "policyA",
Enabled: true,
ID: "policyA",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
ID: "ruleA",
PolicyID: "policyA",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -255,7 +267,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err)
select {
@@ -303,17 +315,19 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, false)
assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
policy = Policy{
ID: "policyB",
Enabled: true,
ID: "policyB",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
ID: "ruleB",
PolicyID: "policyB",
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
@@ -337,7 +351,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err)
select {
@@ -355,11 +369,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
})
policy = Policy{
ID: "policyB",
Enabled: true,
ID: "policyB",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
ID: "ruleB",
PolicyID: "policyB",
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupA"},
@@ -384,7 +400,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err)
select {
@@ -398,10 +414,13 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{
ID: "policyB",
Enabled: true,
ID: "policyB",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "ruleB",
PolicyID: "policyB",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
@@ -429,7 +448,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck, true)
assert.NoError(t, err)
select {
@@ -441,79 +460,126 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
account := &Account{
Policies: []*Policy{
{
ID: "policyA",
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
},
},
Groups: map[string]*group.Group{
"groupA": {
ID: "groupA",
Peers: []string{"peer1"},
},
"groupB": {
ID: "groupB",
Peers: []string{},
},
},
PostureChecks: []*posture.Checks{
{
ID: "checkA",
},
{
ID: "checkB",
},
},
manager, err := createManager(t)
require.NoError(t, err, "failed to create account manager")
accountID, err := initTestPostureChecksAccount(manager)
require.NoError(t, err, "failed to init testing account")
groupA := &group.Group{
ID: "groupA",
AccountID: accountID,
Peers: []string{"peer1"},
}
groupB := &group.Group{
ID: "groupB",
AccountID: accountID,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
policy := &Policy{
ID: "policyA",
AccountID: accountID,
Rules: []*PolicyRule{
{
ID: "ruleA",
PolicyID: "policyA",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
},
},
SourcePostureChecks: []string{"checkA"},
}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to save policy")
postureCheckA := &posture.Checks{
ID: "checkA",
Name: "checkA",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA, false)
require.NoError(t, err, "failed to save postureCheckA")
postureCheckB := &posture.Checks{
ID: "checkB",
Name: "checkB",
AccountID: accountID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
},
}
err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB, false)
require.NoError(t, err, "failed to save postureCheckB")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkB", true)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check does not exist", func(t *testing.T) {
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "unknown", false)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupB"}
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
policy.Rules[0].Sources = []string{"groupB"}
policy.Rules[0].Destinations = []string{"groupA"}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"groupA"}
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
policy.Rules[0].Sources = []string{"groupA"}
policy.Rules[0].Destinations = []string{"groupB"}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
groupA.Peers = []string{}
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
require.NoError(t, err, "failed to save groups")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
account.Groups["groupA"].Peers = []string{}
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
policy.Rules[0].Sources = []string{"nonExistentGroup"}
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err, "failed to update policy")
result, err := manager.arePostureCheckChangesAffectPeers(context.Background(), accountID, "checkA", true)
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, sErr.Type())
assert.False(t, result)
})
}

View File

@@ -52,17 +52,43 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(ctx context.Context, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, accountID, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(ctx, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool)
@@ -81,8 +107,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true
group := account.GetGroup(groupID)
if group == nil {
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil || group == nil {
return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID,
@@ -97,10 +123,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID)
if peer == nil {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -109,7 +136,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
// we validated the group existence before entering this function, no need to check again.
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil || group == nil {
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
}
if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf(
@@ -120,10 +151,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id)
if peer == nil {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -143,16 +175,22 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
// Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peerID != "" {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
@@ -179,22 +217,28 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
var newRoute route.Route
newRoute.ID = route.ID(xid.New().String())
newRoute.AccountID = accountID
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
err = validateGroups(peerGroupIDs, accountGroups)
if err != nil {
return nil, err
}
}
if len(accessControlGroupIDs) > 0 {
err = validateGroups(accessControlGroupIDs, account.Groups)
err = validateGroups(accessControlGroupIDs, accountGroups)
if err != nil {
return nil, err
}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
if err != nil {
return nil, err
}
@@ -207,7 +251,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
err = validateGroups(groups, accountGroups)
if err != nil {
return nil, err
}
@@ -226,30 +270,46 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, &newRoute)
if err != nil {
return nil, err
}
if isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, account)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, &newRoute)
if err != nil {
return fmt.Errorf("failed to create route: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return &newRoute, nil
}
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil")
@@ -263,18 +323,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
account, err := am.Store.GetAccount(ctx, accountID)
oldRoute, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeToSave.ID))
if err != nil {
return err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -291,72 +344,119 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
}
// Do not allow non-Linux peers
if routeToSave.Peer != "" {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer)
if err != nil {
return err
}
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups)
err = validateGroups(routeToSave.PeerGroups, groups)
if err != nil {
return err
}
}
if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
err = validateGroups(routeToSave.AccessControlGroups, groups)
if err != nil {
return err
}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil {
return err
}
err = validateGroups(routeToSave.Groups, account.Groups)
err = validateGroups(routeToSave.Groups, groups)
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
oldRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, oldRoute)
if err != nil {
return err
}
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, account)
newRouteAffectsPeers, err := am.areRouteChangesAffectPeers(ctx, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave)
if err != nil {
return fmt.Errorf("failed to save route: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
updateAccountPeers, err := am.areRouteChangesAffectPeers(ctx, route)
if err != nil {
return err
}
if isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, account)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID)); err != nil {
return fmt.Errorf("failed to delete route: %w", err)
}
return nil
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -369,8 +469,12 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
@@ -649,8 +753,21 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
return &portInfo
}
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func (am *DefaultAccountManager) areRouteChangesAffectPeers(ctx context.Context, route *route.Route) (bool, error) {
if route.Peer != "" {
return true, nil
}
hasPeers, err := am.anyGroupHasPeers(ctx, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return am.anyGroupHasPeers(ctx, route.AccountID, route.PeerGroups)
}

View File

@@ -5,19 +5,20 @@ import (
"fmt"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@@ -427,21 +428,22 @@ func TestCreateRoute(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
accountID, err := initTestRouteAccount(t, am)
if err != nil {
t.Errorf("failed to init testing account: %s", err)
}
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
_, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
}
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err)
@@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) {
// assign generated ID
testCase.expectedRoute.ID = outRoute.ID
testCase.expectedRoute.AccountID = accountID
if !testCase.expectedRoute.IsEqual(outRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute)
@@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
accountID, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
if testCase.createInitRoute {
account.Routes["initRoute"] = &route.Route{
initRoute := &route.Route{
ID: "initRoute",
AccountID: accountID,
Network: existingNetwork,
NetID: existingRouteID,
NetworkType: route.IPv4Network,
@@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) {
Enabled: true,
Groups: []string{routeGroup1},
}
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute)
require.NoError(t, err, "failed to save init route")
}
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("account should be saved")
}
testCase.existingRoute.AccountID = accountID
err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute)
require.NoError(t, err, "failed to save existing route")
var routeToSave *route.Route
@@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) {
}
}
err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave)
err = am.SaveRoute(context.Background(), accountID, userID, routeToSave)
testCase.errFunc(t, err)
@@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) {
return
}
account, err = am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
}
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
require.True(t, saved)
savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID)
require.NoError(t, err, "failed to get saved route")
testCase.expectedRoute.AccountID = accountID
if !testCase.expectedRoute.IsEqual(savedRoute) {
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute)
}
@@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) {
}
func TestDeleteRoute(t *testing.T) {
testingRoute := &route.Route{
ID: "testingRoute",
Network: netip.MustParsePrefix("192.168.0.0/16"),
Domains: domain.List{"domain1", "domain2"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1Key,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
}
am, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
accountID, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.Routes[testingRoute.ID] = testingRoute
err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "GroupA",
AccountID: accountID,
Name: "GroupA",
})
require.NoError(t, err, "failed to save group")
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("failed to save account")
testingRoute := &route.Route{
Network: netip.MustParsePrefix("192.168.0.0/16"),
NetID: route.NetID("12345678901234567890qw"),
Groups: []string{"GroupA"},
KeepRoute: true,
NetworkType: route.IPv4Network,
Peer: peer1ID,
Description: "super",
Masquerade: false,
Metric: 9999,
Enabled: true,
}
createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute)
require.NoError(t, err, "failed to create route")
err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID)
err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID)
if err != nil {
t.Error("deleting route failed with error: ", err)
}
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Error("failed to retrieve saved account with error: ", err)
}
_, found := savedAccount.Routes[testingRoute.ID]
if found {
t.Error("route shouldn't be found after delete")
}
_, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID)
require.NotNil(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, sErr.Type())
}
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
@@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
accountID, err := initTestRouteAccount(t, am)
require.NoError(t, err, "failed to init testing account")
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)
@@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.ListGroups(context.Background(), account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups {
@@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
}
}
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID)
err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID)
require.NoError(t, err)
peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes")
err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID)
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route")
err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID)
err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID)
require.NoError(t, err)
peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes")
err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID)
err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID)
require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1158,16 +1153,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
t.Error("failed to create account manager")
}
account, err := initTestRouteAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
accountID, err := initTestRouteAccount(t, am)
require.NoError(t, err, "failed to init testing account")
newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1181,7 +1174,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
expectedRoute := enabledRoute.Copy()
expectedRoute.Peer = peer1Key
err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute)
err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute)
require.NoError(t, err)
peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1193,7 +1186,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group")
err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID)
err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID)
require.NoError(t, err)
peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID)
@@ -1202,27 +1195,29 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group")
newGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "peer1 group",
Peers: []string{peer1ID},
ID: xid.New().String(),
AccountID: accountID,
Name: "peer1 group",
Peers: []string{peer1ID},
}
err = am.SaveGroup(context.Background(), account.Id, userID, newGroup)
err = am.SaveGroup(context.Background(), accountID, userID, newGroup)
require.NoError(t, err)
rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser")
policies, err := am.ListPolicies(context.Background(), accountID, "testingUser")
require.NoError(t, err)
defaultRule := rules[0]
defaultRule := policies[0]
newPolicy := defaultRule.Copy()
newPolicy.ID = xid.New().String()
newPolicy.Name = "peer1 only"
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
err = am.SavePolicy(context.Background(), accountID, userID, newPolicy, false)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID)
require.NoError(t, err)
peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1233,7 +1228,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2")
err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID)
err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID)
require.NoError(t, err)
peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1267,179 +1262,104 @@ func createRouterStore(t *testing.T) (Store, error) {
return store, nil
}
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) {
t.Helper()
accountID := "testingAcc"
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
if err != nil {
return nil, err
return "", err
}
ips := account.getTakenIPs()
peer1IP, err := AllocatePeerIP(account.Network.Net, ips)
createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) {
ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerIP, err := AllocatePeerIP(network.Net, ips)
if err != nil {
return nil, err
}
peer := &nbpeer.Peer{
IP: peerIP,
AccountID: accountID,
ID: peerID,
Key: peerKey,
Name: peerName,
DNSLabel: dnsLabel,
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: peerName,
GoOS: strings.ToLower(kernel),
Kernel: kernel,
Core: core,
Platform: platform,
OS: os,
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil {
return nil, err
}
return peer, nil
}
// Create peers
peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil {
return nil, err
return "", err
}
peer1 := &nbpeer.Peer{
IP: peer1IP,
ID: peer1ID,
Key: peer1Key,
Name: "test-host1@netbird.io",
DNSLabel: "test-host1",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer1.ID] = peer1
ips = account.getTakenIPs()
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil {
return nil, err
return "", err
}
peer2 := &nbpeer.Peer{
IP: peer2IP,
ID: peer2ID,
Key: peer2Key,
Name: "test-host2@netbird.io",
DNSLabel: "test-host2",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer2.ID] = peer2
ips = account.getTakenIPs()
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin")
if err != nil {
return nil, err
return "", err
}
peer3 := &nbpeer.Peer{
IP: peer3IP,
ID: peer3ID,
Key: peer3Key,
Name: "test-host3@netbird.io",
DNSLabel: "test-host3",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
Kernel: "Darwin",
Core: "13.4.1",
Platform: "arm64",
OS: "darwin",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer3.ID] = peer3
ips = account.getTakenIPs()
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil {
return nil, err
return "", err
}
peer4 := &nbpeer.Peer{
IP: peer4IP,
ID: peer4ID,
Key: peer4Key,
Name: "test-host4@netbird.io",
DNSLabel: "test-host4",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer4.ID] = peer4
ips = account.getTakenIPs()
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu")
if err != nil {
return nil, err
return "", err
}
peer5 := &nbpeer.Peer{
IP: peer5IP,
ID: peer5ID,
Key: peer5Key,
Name: "test-host5@netbird.io",
DNSLabel: "test-host5",
UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host5@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
groupAll, err := am.GetGroupByName(context.Background(), "All", accountID)
if err != nil {
return "", err
}
account.Peers[peer5.ID] = peer5
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
groupAll, err := account.GetGroupAll()
if err != nil {
return nil, err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID)
if err != nil {
return nil, err
return "", err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID)
if err != nil {
return nil, err
return "", err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID)
if err != nil {
return nil, err
return "", err
}
err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID)
if err != nil {
return nil, err
return "", err
}
newGroup := []*nbgroup.Group{
newGroups := []*nbgroup.Group{
{
ID: routeGroup1,
Name: routeGroup1,
@@ -1471,15 +1391,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
Peers: []string{peer1.ID, peer4.ID},
},
}
for _, group := range newGroup {
err = am.SaveGroup(context.Background(), accountID, userID, group)
if err != nil {
return nil, err
}
err = am.SaveGroups(context.Background(), accountID, userID, newGroups)
if err != nil {
return "", err
}
return am.Store.GetAccount(context.Background(), account.Id)
return accountID, nil
}
func TestAccount_getPeersRoutesFirewall(t *testing.T) {
@@ -1783,10 +1700,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager)
accountID, err := initTestRouteAccount(t, manager)
require.NoError(t, err, "failed to init testing account")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{
ID: "groupA",
Name: "GroupA",
@@ -1832,7 +1749,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}()
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
@@ -1868,7 +1785,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}()
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
@@ -1904,7 +1821,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}()
newRoute, err := manager.CreateRoute(
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
)
@@ -1928,7 +1845,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute)
require.NoError(t, err)
select {
@@ -1946,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID)
require.NoError(t, err)
select {
@@ -1970,7 +1887,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{routeGroup1},
}
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
@@ -1982,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupB",
Name: "GroupB",
Peers: []string{peer1ID},
@@ -2010,7 +1927,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
Groups: []string{"groupC"},
}
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
@@ -2022,7 +1939,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
close(done)
}()
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "groupC",
Name: "GroupC",
Peers: []string{peer1ID},

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/fnv"
"strconv"
"strings"
@@ -12,6 +11,7 @@ import (
"unicode/utf8"
"github.com/google/uuid"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
@@ -226,34 +226,44 @@ func Hash(s string) uint32 {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if err = validateSetupKeyAutoGroups(groups, autoGroups); err != nil {
return nil, err
}
setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
account.SetupKeys[setupKey.Key] = setupKey
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, status.Errorf(status.Internal, "failed adding account key")
setupKey.AccountID = accountID
if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey); err != nil {
return nil, err
}
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta())
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, g := range setupKey.AutoGroups {
group := account.GetGroup(g)
group := groupMap[g]
if group != nil {
am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
}
}
@@ -268,30 +278,30 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
}
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
var oldKey *SetupKey
for _, key := range account.SetupKeys {
if key.Id == keyToSave.Id {
oldKey = key.Copy()
break
}
}
if oldKey == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if err = validateSetupKeyAutoGroups(groups, keyToSave.AutoGroups); err != nil {
return nil, err
}
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
if err != nil {
return nil, err
}
@@ -302,9 +312,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
account.SetupKeys[newKey.Key] = newKey
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = am.Store.SaveSetupKey(ctx, LockingStrengthUpdate, newKey); err != nil {
return nil, err
}
@@ -315,24 +323,30 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
defer func() {
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, g := range removedGroups {
group := account.GetGroup(g)
group := groupMap[g]
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
group := groupMap[g]
if group != nil {
am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey,
map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name})
} else {
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id)
log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, accountID)
}
}
}()
@@ -347,8 +361,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
@@ -366,8 +384,12 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return nil, status.NewAdminPermissionError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
@@ -387,21 +409,25 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
return err
}
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return status.NewUnauthorizedToViewSetupKeysError()
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}
deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
return err
}
err = am.Store.DeleteSetupKey(ctx, accountID, keyID)
err = am.Store.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
if err != nil {
return fmt.Errorf("failed to delete setup key: %w", err)
return err
}
am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta())
@@ -409,15 +435,22 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return nil
}
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
for _, group := range autoGroups {
g, ok := account.Groups[group]
if !ok {
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
func validateSetupKeyAutoGroups(groups []*nbgroup.Group, autoGroups []string) error {
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
for _, groupID := range autoGroups {
g, exists := groupMap[groupID]
if !exists {
return status.Errorf(status.NotFound, "group %s doesn't exist", groupID)
}
if g.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
return status.Errorf(status.InvalidArgument, "can't add 'All' group to the setup key")
}
}
return nil
}

View File

@@ -25,21 +25,21 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
require.NoError(t, err, "failed to get or create account ID")
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{
{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
ID: "group_1",
AccountID: accountID,
Name: "group_name_1",
Peers: []string{},
},
{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
ID: "group_2",
AccountID: accountID,
Name: "group_name_2",
Peers: []string{},
},
})
if err != nil {
@@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
expiresIn := time.Hour
keyName := "my-test-key"
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{},
key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{},
SetupKeyUnlimitedUsage, userID, false)
if err != nil {
t.Fatal(err)
@@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups := []string{"group_1", "group_2"}
newKeyName := "my-new-test-key"
revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
@@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked)
ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked)
assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID)
assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, newKeyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID)
assert.Equal(t, key.Id, ev.TargetID)
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
require.NoError(t, err)
// saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
_, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{
Id: key.Id,
Name: newKeyName,
Revoked: revoked,
@@ -103,31 +103,31 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
require.NoError(t, err, "failed to get or create account ID")
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1",
AccountID: accountID,
Name: "group_name_1",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2",
AccountID: accountID,
Name: "group_name_2",
Peers: []string{},
})
if err != nil {
t.Fatal(err)
}
groupAll, err := account.GetGroupAll()
assert.NoError(t, err)
groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID)
require.NoError(t, err)
type testCase struct {
name string
@@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
if tCase.expectedFailure {
@@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated
ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated)
ev := getEvent(t, accountID, manager, activity.SetupKeyCreated)
assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID)
assert.Equal(t, accountID, ev.AccountID)
assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"])
assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"])
@@ -208,12 +208,10 @@ func TestGetSetupKeys(t *testing.T) {
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
require.NoError(t, err, "failed to get or create account ID")
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_1",
Name: "group_name_1",
Peers: []string{},
@@ -222,7 +220,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err)
}
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{
ID: "group_2",
Name: "group_name_2",
Peers: []string{},
@@ -384,20 +382,24 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
ID: "groupA",
AccountID: account.Id,
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
policy := Policy{
ID: "policy",
Enabled: true,
ID: "policy",
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "Rule",
PolicyID: "policy",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"group"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},

File diff suppressed because it is too large Load Diff

View File

@@ -68,20 +68,27 @@ func TestSqlite_SaveAccount_Large(t *testing.T) {
func runLargeTest(t *testing.T, store Store) {
t.Helper()
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
groupALL, err := account.GetGroupAll()
if err != nil {
t.Fatal(err)
}
accountID := "account_id"
err := newAccountWithId(context.Background(), store, accountID, "testuser", "")
assert.NoError(t, err, "failed to create account")
groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
assert.NoError(t, err, "failed to get group All")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
assert.NoError(t, err, "failed to save setup key")
const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ {
netIP := randomIPv4()
peerID := fmt.Sprintf("%s-peer-%d", account.Id, n)
peerID := fmt.Sprintf("%s-peer-%d", accountID, n)
peer := &nbpeer.Peer{
ID: peerID,
AccountID: accountID,
Key: peerID,
IP: netIP,
Name: peerID,
@@ -90,16 +97,21 @@ func runLargeTest(t *testing.T, store Store) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
account.Users[user.Id] = user
err = store.AddPeerToAccount(context.Background(), peer)
assert.NoError(t, err, "failed to add peer")
err = store.AddPeerToAllGroup(context.Background(), accountID, peerID)
assert.NoError(t, err, "failed to add peer to all group")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: fmt.Sprintf("%s-user-%d", accountID, n),
AccountID: accountID,
})
assert.NoError(t, err, "failed to save user")
route := &route2.Route{
ID: route2.ID(fmt.Sprintf("network-id-%d", n)),
AccountID: accountID,
Description: "base route",
NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)),
Network: netip.MustParsePrefix(netIP.String() + "/24"),
@@ -107,22 +119,24 @@ func runLargeTest(t *testing.T, store Store) {
Metric: 9999,
Masquerade: false,
Enabled: true,
Groups: []string{groupALL.ID},
Groups: []string{groupAll.ID},
}
account.Routes[route.ID] = route
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
assert.NoError(t, err, "failed to save route")
group = &nbgroup.Group{
group := &nbgroup.Group{
ID: fmt.Sprintf("group-id-%d", n),
AccountID: account.Id,
AccountID: accountID,
Name: fmt.Sprintf("group-id-%d", n),
Issued: "api",
Peers: nil,
}
account.Groups[group.ID] = group
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
assert.NoError(t, err, "failed to save group")
nameserver := &nbdns.NameServerGroup{
ID: fmt.Sprintf("nameserver-id-%d", n),
AccountID: account.Id,
AccountID: accountID,
Name: fmt.Sprintf("nameserver-id-%d", n),
Description: "",
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}},
@@ -132,20 +146,20 @@ func runLargeTest(t *testing.T, store Store) {
Enabled: false,
SearchDomainsEnabled: false,
}
account.NameServerGroups[nameserver.ID] = nameserver
err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver)
assert.NoError(t, err, "failed to save nameserver group")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
setupKey, _ = GenerateDefaultSetupKey()
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
assert.NoError(t, err, "failed to save setup key")
}
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(context.Background(), account.Id)
a, err := store.GetAccount(context.Background(), accountID)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
@@ -213,41 +227,53 @@ func TestSqlite_SaveAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
require.NoError(t, err)
accountID := "account_id"
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
require.NoError(t, err, "failed to create account")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testpeer",
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
AccountID: accountID,
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
accountID2 := "account_id2"
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
require.NoError(t, err, "failed to create account")
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
setupKey.AccountID = accountID2
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SaveAccount(context.Background(), account2)
require.NoError(t, err)
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testpeer2",
Key: "peerkey2",
AccountID: accountID2,
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
if len(store.GetAllAccounts(context.Background())) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(context.Background(), account.Id)
a, err := store.GetAccount(context.Background(), accountID)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
@@ -288,36 +314,52 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Cleanup(cleanUp)
assert.NoError(t, err)
accountID := "account_id"
testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
err = store.SaveAccount(context.Background(), account)
err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 1 {
setupKey, _ := GenerateDefaultSetupKey()
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testpeer",
Key: "peerkey",
AccountID: accountID,
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: "testtoken",
UserID: testUserID,
Name: "test token",
})
require.NoError(t, err, "failed to save personal access token")
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 {
accountIDs, err = store.GetAllAccountIDs(context.Background(), LockingStrengthShare)
require.NoError(t, err, "failed to get all account ids after DeleteAccount()")
if len(accountIDs) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
}
@@ -400,7 +442,7 @@ func TestSqlite_SavePeer(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
ctx := context.Background()
err = store.SavePeer(ctx, account.Id, peer)
err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@@ -416,7 +458,7 @@ func TestSqlite_SavePeer(t *testing.T) {
updatedPeer.Status.Connected = false
updatedPeer.Meta.Hostname = "updatedpeer"
err = store.SavePeer(ctx, account.Id, updatedPeer)
err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
@@ -442,7 +484,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
// save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@@ -461,7 +503,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
@@ -472,7 +514,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
newStatus.Connected = true
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
@@ -507,7 +549,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
Meta: nbpeer.PeerSystemMeta{},
}
// error is expected as peer is not in store yet
err = store.SavePeerLocation(account.Id, peer)
err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer)
assert.Error(t, err)
account.Peers[peer.ID] = peer
@@ -519,7 +561,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
peer.Location.CityName = "Berlin"
peer.Location.GeoNameID = 2950159
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
@@ -529,7 +571,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
assert.Equal(t, peer.Location, actual)
peer.ID = "non-existing-peer"
err = store.SavePeerLocation(account.Id, peer)
err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@@ -572,11 +614,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err)
require.Equal(t, id, token)
require.Equal(t, id, pat.ID)
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
_, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@@ -595,11 +637,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
_, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@@ -714,19 +756,28 @@ func newSqliteStore(t *testing.T) *SqlStore {
}
func newAccount(store Store, id int) error {
str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com")
accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id)
userID := accountID + "-testuser"
err := newAccountWithId(context.Background(), store, accountID, userID, "example.com")
if err != nil {
return err
}
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["p"+str] = &nbpeer.Peer{
Key: "peerkey" + str,
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
if err != nil {
return err
}
return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{
Key: accountID + "-peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
return store.SaveAccount(context.Background(), account)
})
}
func TestPostgresql_NewStore(t *testing.T) {
@@ -754,39 +805,56 @@ func TestPostgresql_SaveAccount(t *testing.T) {
t.Cleanup(cleanUp)
assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
accountID := "account_id"
err = newAccountWithId(context.Background(), store, accountID, "testuser", "")
require.NoError(t, err, "failed to create account")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testpeer",
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
AccountID: accountID,
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
accountID2 := "account_id2"
err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "")
require.NoError(t, err, "failed to create account")
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
setupKey.AccountID = accountID2
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SaveAccount(context.Background(), account2)
require.NoError(t, err)
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testpeer2",
Key: "peerkey2",
AccountID: accountID2,
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
if len(store.GetAllAccounts(context.Background())) != 2 {
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(context.Background(), account.Id)
a, err := store.GetAccount(context.Background(), accountID)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
@@ -827,32 +895,51 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
t.Cleanup(cleanUp)
assert.NoError(t, err)
accountID := "account_id"
testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
err = newAccountWithId(context.Background(), store, accountID, testUserID, "")
require.NoError(t, err, "failed to create account")
setupKey, _ := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
setupKey.AccountID = accountID
err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey)
require.NoError(t, err, "failed to save setup key")
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{
ID: "testingpeer",
AccountID: accountID,
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
})
require.NoError(t, err, "failed to save peer")
if len(store.GetAllAccounts(context.Background())) != 1 {
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: "testtoken",
UserID: testUserID,
Name: "test token",
})
require.NoError(t, err, "failed to save personal access token")
accountIDs, err := store.GetAllAccountIDs(context.Background(), LockingStrengthUpdate)
require.NoError(t, err, "failed to get all account ids")
if len(accountIDs) != 1 {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "failed to get account")
err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err)
@@ -908,7 +995,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
// save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
// save new status of existing peer
@@ -924,7 +1011,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus)
require.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
@@ -967,9 +1054,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err)
require.Equal(t, id, token)
require.Equal(t, id, pat.ID)
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
@@ -984,7 +1071,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
@@ -1047,7 +1134,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
_, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
labels, err := store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{}, labels)
@@ -1059,7 +1146,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), peer1)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test"}, labels)
@@ -1071,7 +1158,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
err = store.AddPeerToAccount(context.Background(), peer2)
require.NoError(t, err)
labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID)
require.NoError(t, err)
assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
}
@@ -1181,7 +1268,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil {
t.Fatal("failed to get group")
return err
@@ -1201,7 +1288,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
users, err := store.GetAccountUsers(context.Background(), accountID)
users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
@@ -1218,7 +1305,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "example.com"
category := "public"
IsDomainPrimaryAccount := false
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
@@ -1232,7 +1319,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount)
err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount)
require.NoError(t, err)
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
@@ -1246,7 +1333,9 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
domain := "test.com"
category := "private"
IsDomainPrimaryAccount := true
err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount)
err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id",
domain, category, &IsDomainPrimaryAccount,
)
require.Error(t, err)
})
@@ -1274,7 +1363,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID)
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID)
require.NoError(t, err)
_, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID)
@@ -1290,6 +1379,6 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
nonExistingKeyID := "non-existing-key-id"
err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID)
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err)
}

View File

@@ -83,8 +83,8 @@ func NewPeerNotFoundError(peerKey string) error {
}
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account
func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey)
func NewAccountNotFoundError() error {
return Errorf(NotFound, "account not found")
}
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
@@ -102,23 +102,38 @@ func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found: %s", err)
}
func NewGetAccountFromStoreError(err error) error {
return Errorf(Internal, "issue getting account from store: %s", err)
}
func NewUnauthorizedToViewAccountSettingError() error {
return Errorf(PermissionDenied, "only users with admin power can view account settings")
}
// NewUserNotPartOfAccountError creates a new Error with PermissionDenied type for a user not being part of an account
func NewUserNotPartOfAccountError() error {
return Errorf(PermissionDenied, "user is not part of this account")
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
// NewAdminPermissionError creates a new Error with PermissionDenied type for actions requiring admin role.
func NewAdminPermissionError() error {
return Errorf(PermissionDenied, "admin role required to perform this action")
}
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
// to delete a user with the owner role.
func NewOwnerDeletePermissionError() error {
return Errorf(PermissionDenied, "can't delete a user with the owner role")
}
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
func NewServiceUserRoleInvalidError() error {
return Errorf(InvalidArgument, "can't create a service user with owner role")
}
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
@@ -126,7 +141,24 @@ func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")
}
// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key
func NewUnauthorizedToViewSetupKeysError() error {
return Errorf(Unauthorized, "only users with admin power can view setup keys")
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
func NewSetupKeyNotFoundError(err error) error {
return Errorf(NotFound, "setup key not found: %s", err)
}
func NewPATNotFoundError() error {
return Errorf(NotFound, "PAT not found")
}
func NewGetPATFromStoreError() error {
return Errorf(Internal, "issue getting pat from store")
}
func NewUnauthorizedToViewNSGroupsError() error {
return Errorf(PermissionDenied, "only users with admin power can view name server groups")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
}

View File

@@ -47,65 +47,95 @@ type Store interface {
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error)
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string
@@ -124,7 +154,6 @@ type Store interface {
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
DeleteSetupKey(ctx context.Context, accountID, keyID string) error
}
type StoreEngine string

View File

@@ -32,4 +32,7 @@ INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-3465300
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,'');
INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]');
INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL);
INSERT INTO installations VALUES(1,'');

File diff suppressed because it is too large Load Diff

View File

@@ -43,37 +43,34 @@ const (
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
}
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err)
}
assert.Equal(t, pat.CreatedBy, mockUserID)
assert.Equal(t, newPAT.CreatedBy, mockUserID)
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
if err != nil {
t.Fatalf("Error when getting token ID by hashed token: %s", err)
}
if tokenID == "" {
if pat.ID == "" {
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
}
assert.Equal(t, pat.ID, tokenID)
assert.Equal(t, newPAT.ID, pat.ID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}
@@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: false,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
})
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{
Store: store,
@@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockTargetUserId,
AccountID: mockAccountID,
IsServiceUser: true,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
})
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{
Store: store,
@@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
}
_, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn)
assert.Errorf(t, err, "Wrong expiration should thorw error")
assert.Errorf(t, err, "Wrong expiration should throw error")
}
func TestUser_DeletePAT(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
HashedToken: mockToken1,
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID1,
UserID: mockUserID,
HashedToken: mockToken1,
})
assert.NoError(t, err, "failed to create PAT")
am := DefaultAccountManager{
Store: store,
@@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err)
}
account, err = store.GetAccount(context.Background(), mockAccountID)
account, err := store.GetAccount(context.Background(), mockAccountID)
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
@@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) {
func TestUser_GetPAT(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
HashedToken: mockToken1,
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID1,
UserID: mockUserID,
HashedToken: mockToken1,
})
assert.NoError(t, err, "failed to create PAT")
am := DefaultAccountManager{
Store: store,
@@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) {
func TestUser_GetAllPATs(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
Id: mockUserID,
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
HashedToken: mockToken1,
},
mockTokenID2: {
ID: mockTokenID2,
HashedToken: mockToken2,
},
},
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID1,
UserID: mockUserID,
HashedToken: mockToken1,
})
assert.NoError(t, err, "failed to create PAT")
err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{
ID: mockTokenID2,
UserID: mockUserID,
HashedToken: mockToken2,
})
assert.NoError(t, err, "failed to create PAT")
am := DefaultAccountManager{
Store: store,
@@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) {
func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when creating service user: %s", err)
}
account, err = store.GetAccount(context.Background(), mockAccountID)
account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err)
assert.Equal(t, 2, len(account.Users))
@@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) {
func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when creating user: %s", err)
}
account, err = store.GetAccount(context.Background(), mockAccountID)
account, err := store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err)
assert.True(t, user.IsServiceUser)
@@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
func TestUser_InviteNewUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -549,13 +519,12 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = tt.serviceUser
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser)
assert.NoError(t, err, "failed to create service user")
am := DefaultAccountManager{
Store: store,
@@ -582,12 +551,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -603,39 +569,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
targetId = "user5"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
{
Id: "user2",
AccountID: mockAccountID,
IsServiceUser: true,
ServiceUserName: "user2username",
},
{
Id: "user3",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user4",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedIntegration,
},
{
Id: "user5",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
},
})
assert.NoError(t, err, "failed to save users")
am := DefaultAccountManager{
Store: store,
@@ -685,61 +650,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
targetId = "user5"
account.Users[targetId] = &User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
}
account.Users["user6"] = &User{
Id: "user6",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user7"] = &User{
Id: "user7",
IsServiceUser: false,
Issued: UserIssuedAPI,
}
account.Users["user8"] = &User{
Id: "user8",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
account.Users["user9"] = &User{
Id: "user9",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{
{
Id: "user2",
AccountID: mockAccountID,
IsServiceUser: true,
ServiceUserName: "user2username",
},
{
Id: "user3",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user4",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedIntegration,
},
{
Id: "user5",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
},
{
Id: "user6",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user7",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
},
{
Id: "user8",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
{
Id: "user9",
AccountID: mockAccountID,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
},
})
assert.NoError(t, err)
am := DefaultAccountManager{
Store: store,
@@ -816,7 +784,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err)
}
acc, err := am.Store.GetAccount(context.Background(), account.Id)
acc, err := am.Store.GetAccount(context.Background(), mockAccountID)
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
@@ -836,12 +804,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
am := DefaultAccountManager{
Store: store,
@@ -865,14 +830,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
newUser := NewRegularUser("normal_user1")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
newUser = NewRegularUser("normal_user2")
newUser.AccountID = mockAccountID
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{
Store: store,
@@ -946,15 +916,24 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser)
assert.NoError(t, err, "failed to create user")
settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID)
assert.NoError(t, err, "failed to get account settings")
settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings)
assert.NoError(t, err, "failed to save account settings")
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID)
assert.NoError(t, err, "failed to delete user")
am := DefaultAccountManager{
Store: store,
@@ -968,7 +947,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
assert.Equal(t, 1, len(users))
userInfo, _ := users[0].ToUserInfo(nil, account.Settings)
userInfo, _ := users[0].ToUserInfo(nil, settings)
assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView)
})
}
@@ -978,22 +957,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
externalUser := &User{
Id: "externalUser",
Role: UserRoleUser,
Issued: UserIssuedIntegration,
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: "externalUser",
AccountID: mockAccountID,
Role: UserRoleUser,
Issued: UserIssuedIntegration,
IntegrationReference: integration_reference.IntegrationReference{
ID: 1,
IntegrationType: "external",
},
}
account.Users[externalUser.Id] = externalUser
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
})
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{
Store: store,
@@ -1013,6 +991,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager()
externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser")
assert.NoError(t, err, "failed to get user")
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"})
assert.NoError(t, err)
@@ -1042,17 +1024,17 @@ func TestUser_IsAdmin(t *testing.T) {
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
})
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{
Store: store,
@@ -1071,17 +1053,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "")
assert.NoError(t, err, "failed to create account")
err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{
Id: mockServiceUserID,
AccountID: mockAccountID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
})
assert.NoError(t, err, "failed to create user")
am := DefaultAccountManager{
Store: store,
@@ -1240,21 +1221,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io")
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io")
if err != nil {
t.Fatal(err)
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
regularUser := NewRegularUser(regularUserID)
regularUser.AccountID = accountID
adminUser := NewAdminUser(adminUserID)
adminUser.AccountID = accountID
serviceUser := &User{
Id: serviceUserID,
AccountID: accountID,
IsServiceUser: true,
Role: UserRoleAdmin,
ServiceUserName: "service",
}
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update)
err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser})
assert.NoError(t, err, "failed to save users")
updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update)
if tc.expectedErr {
require.Errorf(t, err, "expecting SaveUser to throw an error")
} else {

View File

@@ -88,18 +88,18 @@ type Route struct {
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
// Network and Domains are mutually exclusive
Network netip.Prefix `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"`
KeepRoute bool
NetID NetID
Description string
Peer string
PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType
Masquerade bool
Metric int
Enabled bool
Groups []string `gorm:"serializer:json"`
Network netip.Prefix `gorm:"serializer:json"`
Domains domain.List `gorm:"serializer:json"`
KeepRoute bool
NetID NetID
Description string
Peer string
PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType
Masquerade bool
Metric int
Enabled bool
Groups []string `gorm:"serializer:json"`
AccessControlGroups []string `gorm:"serializer:json"`
}
@@ -111,19 +111,20 @@ func (r *Route) EventMeta() map[string]any {
// Copy copies a route object
func (r *Route) Copy() *Route {
route := &Route{
ID: r.ID,
Description: r.Description,
NetID: r.NetID,
Network: r.Network,
Domains: slices.Clone(r.Domains),
KeepRoute: r.KeepRoute,
NetworkType: r.NetworkType,
Peer: r.Peer,
PeerGroups: slices.Clone(r.PeerGroups),
Metric: r.Metric,
Masquerade: r.Masquerade,
Enabled: r.Enabled,
Groups: slices.Clone(r.Groups),
ID: r.ID,
AccountID: r.AccountID,
Description: r.Description,
NetID: r.NetID,
Network: r.Network,
Domains: slices.Clone(r.Domains),
KeepRoute: r.KeepRoute,
NetworkType: r.NetworkType,
Peer: r.Peer,
PeerGroups: slices.Clone(r.PeerGroups),
Metric: r.Metric,
Masquerade: r.Masquerade,
Enabled: r.Enabled,
Groups: slices.Clone(r.Groups),
AccessControlGroups: slices.Clone(r.AccessControlGroups),
}
return route
@@ -138,6 +139,7 @@ func (r *Route) IsEqual(other *Route) bool {
}
return other.ID == r.ID &&
other.AccountID == r.AccountID &&
other.Description == r.Description &&
other.NetID == r.NetID &&
other.Network == r.Network &&
@@ -149,7 +151,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled &&
slices.Equal(r.Groups, other.Groups) &&
slices.Equal(r.PeerGroups, other.PeerGroups)&&
slices.Equal(r.PeerGroups, other.PeerGroups) &&
slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
}