From 6aa4ba7af441c9c07fd76d7946e179ea5257e975 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 7 Nov 2025 10:44:46 +0100 Subject: [PATCH] [management] incremental network map builder (#4753) --- go.mod | 2 +- management/main.go | 10 +- management/server/account.go | 36 + management/server/account/manager.go | 1 + management/server/account_test.go | 55 + management/server/dns.go | 3 + management/server/group.go | 24 + management/server/grpcserver.go | 85 +- management/server/holder.go | 39 + management/server/mock_server/account_mock.go | 14 +- management/server/nameserver.go | 9 + management/server/networkmap.go | 80 + management/server/networks/manager.go | 3 + .../server/networks/resources/manager.go | 9 + management/server/networks/routers/manager.go | 9 + management/server/peer.go | 133 +- management/server/peer_test.go | 23 +- management/server/policy.go | 6 + management/server/posture_checks.go | 3 + management/server/route.go | 9 + management/server/settings/manager.go | 8 + management/server/store/sql_store.go | 1265 ++++++++++- .../store/sql_store_get_account_test.go | 1089 ++++++++++ .../server/store/sqlstore_bench_test.go | 951 ++++++++ management/server/store/store.go | 104 +- management/server/types/account.go | 13 + management/server/types/holder.go | 43 + management/server/types/networkmap.go | 58 + .../server/types/networkmap_golden_test.go | 1069 +++++++++ management/server/types/networkmapbuilder.go | 1932 +++++++++++++++++ management/server/updatechannel.go | 6 +- management/server/user.go | 4 + route/route.go | 1 + 33 files changed, 7018 insertions(+), 78 deletions(-) create mode 100644 management/server/holder.go create mode 100644 management/server/networkmap.go create mode 100644 management/server/store/sql_store_get_account_test.go create mode 100644 management/server/store/sqlstore_bench_test.go create mode 100644 management/server/types/holder.go create mode 100644 management/server/types/networkmap.go create mode 100644 management/server/types/networkmap_golden_test.go create mode 100644 management/server/types/networkmapbuilder.go diff --git a/go.mod b/go.mod index d93d2c064..68a12908d 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/mdlayher/socket v0.5.1 @@ -183,7 +184,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/management/main.go b/management/main.go index 561ed8f26..ff8482f97 100644 --- a/management/main.go +++ b/management/main.go @@ -1,11 +1,19 @@ package main import ( - "github.com/netbirdio/netbird/management/cmd" + "log" + "net/http" + // nolint:gosec + _ "net/http/pprof" "os" + + "github.com/netbirdio/netbird/management/cmd" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() if err := cmd.Execute(); err != nil { os.Exit(1) } diff --git a/management/server/account.go b/management/server/account.go index dca105ddf..0aecbd586 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -53,6 +53,9 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + + envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -109,6 +112,11 @@ type DefaultAccountManager struct { loginFilter *loginFilter disableDefaultPolicy bool + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} } func isUniqueConstraintError(err error) bool { @@ -196,6 +204,18 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + am := &DefaultAccountManager{ Store: store, geo: geo, @@ -217,6 +237,10 @@ func BuildManager( permissionsManager: permissionsManager, loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, + holder: types.NewHolder(), + + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, } am.startWarmup(ctx) @@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } go am.UpdateAccountPeers(ctx, accountID) } @@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { + return err + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -2129,6 +2160,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us } if updateNetworkMap { + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + am.updatePeerInNetworkMapCache(peer.AccountID, peer) am.BufferUpdateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index fe9fb25c6..db377865a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -128,4 +128,5 @@ type Manager interface { GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool + RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account_test.go b/management/server/account_test.go index 07d2f2383..200ba6b98 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } +func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SaveGroup(t) +} + func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + testAccountManager_NetworkUpdates_SaveGroup(t) +} + +func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + +func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { manager, account, peer1, _, _ := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SavePolicy(t) +} + func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_SavePolicy(t) +} + +func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ @@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePeer(t) +} + func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePeer(t) +} + +func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + +func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } + for drained := false; !drained; { + select { + case <-updMsg: + default: + drained = true + } + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, + NetworkMapCache: &types.NetworkMapBuilder{}, } + account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) diff --git a/management/server/dns.go b/management/server/dns.go index e5166ce47..decc5175d 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index a29c28892..3cf9290a2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -481,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -519,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -547,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -585,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 12b59b691..0a5236cb3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,8 +7,10 @@ import ( "net" "net/netip" "os" + "strconv" "strings" "sync" + "sync/atomic" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -44,6 +46,9 @@ import ( const ( envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" envBlockPeers = "NB_BLOCK_SAME_PEERS" + envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS" + + defaultSyncLim = 1000 ) // GRPCServer an instance of a Management gRPC API server @@ -63,6 +68,9 @@ type GRPCServer struct { logBlockedPeers bool blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + + syncSem atomic.Int32 + syncLim int32 } // NewServer creates a new Management server @@ -96,6 +104,17 @@ func NewServer( logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + syncLim := int32(defaultSyncLim) + if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { + syncLimParsed, err := strconv.Atoi(syncLimStr) + if err != nil { + log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim) + } else { + //nolint:gosec + syncLim = int32(syncLimParsed) + } + } + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -110,6 +129,8 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + + syncLim: syncLim, }, nil } @@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + if s.syncSem.Load() >= s.syncLim { + return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") + } + s.syncSem.Add(1) + reqStart := time.Now() ctx := srv.Context() @@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi syncReq := &proto.SyncRequest{} peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { + s.syncSem.Add(-1) return err } realIP := getRealIP(ctx) @@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { + s.syncSem.Add(-1) return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) } } @@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) - unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) - defer func() { - if unlock != nil { - unlock() - } - }() - accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN") log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String()) if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + s.syncSem.Add(-1) return status.Errorf(codes.PermissionDenied, "peer is not registered") } + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + start := time.Now() + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) + log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart)) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { @@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return mapError(ctx, err) } + log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) + log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) + s.ephemeralManager.OnPeerConnected(ctx, peer) + log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) + s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) + if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) + s.syncSem.Add(-1) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } + took := time.Since(reqStart) + if took > 7*time.Second { + log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart)) + } }() if loginReq.GetMeta() == nil { @@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, mapError(ctx, err) } + log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) + // if the login request contains setup key then it is a registration request if loginReq.GetSetupKey() != "" { s.ephemeralManager.OnPeerDisconnected(ctx, peer) + log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart)) } loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) @@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, status.Errorf(codes.Internal, "failed logging in peer") } + log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -716,6 +775,11 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("toSyncResponse: took %s", time.Since(start)) + }() + response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ @@ -780,6 +844,11 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("sendInitialSync: took %s", time.Since(start)) + }() + var err error var turnToken *Token @@ -822,10 +891,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } + sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ WgPubKey: s.wgKey.PublicKey().String(), Body: encryptedResp, }) + log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) if err != nil { log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) diff --git a/management/server/holder.go b/management/server/holder.go new file mode 100644 index 000000000..e8a26e1d0 --- /dev/null +++ b/management/server/holder.go @@ -0,0 +1,39 @@ +package server + +import ( + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { + a := am.holder.GetAccount(account.Id) + if a == nil { + am.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + am.holder.AddAccount(account) +} + +func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { + return am.holder.GetAccount(accountID) +} + +func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { + a := am.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { + am.holder.AddAccount(account) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e87043f26..8baffa58b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -125,9 +125,10 @@ type MockAccountManager struct { UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - AllowSyncFunc func(string, uint64) bool - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { } return true } + +func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error { + if am.RecalculateNetworkMapCacheFunc != nil { + return am.RecalculateNetworkMapCacheFunc(ctx, accountID) + } + return nil +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f278e1761..ee77a65bb 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/networkmap.go b/management/server/networkmap.go new file mode 100644 index 000000000..2a0627643 --- /dev/null +++ b/management/server/networkmap.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + am.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (am *DefaultAccountManager) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := am.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := am.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + am.updateAccountInHolder(account) +} + +func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if am.experimentalNetworkMap(accountId) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + am.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { + _, ok := am.expNewNetworkMapAIDs[accountId] + return am.expNewNetworkMap || ok +} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..0e6d1631b 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 66484d120..b740610c2 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..89ac419fd 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 30b7073ef..80ab7fc69 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -145,6 +145,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. am.BufferUpdateAccountPeers(ctx, accountID) @@ -321,6 +324,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + if peerLabelChanged || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) } else if sshChanged { @@ -381,6 +388,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + } + + } + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -417,7 +436,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(peer.AccountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -690,6 +715,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } + + if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) + } + } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) @@ -776,6 +812,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } am.BufferUpdateAccountPeers(ctx, accountID) } @@ -783,6 +822,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("handlePeerNotFound: took %s", time.Since(start)) + }() if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -804,6 +847,11 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("LoginPeer: took %s", time.Since(start)) + }() + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -831,6 +879,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + startTransaction := time.Now() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { @@ -900,8 +949,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + startBuffer := time.Now() am.BufferUpdateAccountPeers(ctx, accountID) + log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) @@ -909,6 +965,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer // getPeerPostureChecks returns the posture checks for the peer. func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getPostureChecks: took %s", time.Since(start)) + }() + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err @@ -1014,9 +1075,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, emptyMap, nil, nil } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } } approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) @@ -1024,10 +1093,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } + startPosture := time.Now() postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) @@ -1037,7 +1108,13 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1167,11 +1244,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) + return + } } globalStart := time.Now() @@ -1204,6 +1288,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + if am.experimentalNetworkMap(accountID) { + am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -1241,7 +1329,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) start = time.Now() @@ -1257,7 +1351,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) }(peer) } @@ -1351,7 +1445,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1368,7 +1468,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1511,6 +1611,10 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getPeerGroupIDs: took %s", time.Since(start)) + }() return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } @@ -1580,7 +1684,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }, }, }, - NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 3b2ab87fc..e151f5abb 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) { } func TestAccountManager_GetNetworkMap(t *testing.T) { + testGetNetworkMapGeneral(t) +} + +func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testGetNetworkMapGeneral(t) +} + +func testGetNetworkMapGeneral(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } +func TestUpdateAccountPeers_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testUpdateAccountPeers(t) +} + func TestUpdateAccountPeers(t *testing.T) { + testUpdateAccountPeers(t) +} + +func testUpdateAccountPeers(t *testing.T) { testCases := []struct { name string peers int @@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel assert.Nil(t, update.Update.NetbirdConfig) - assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) - assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) + assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) + assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } }) } @@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } diff --git a/management/server/policy.go b/management/server/policy.go index 9e4b3f73a..ff02d46aa 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 943f2a970..f457b994b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/route.go b/management/server/route.go index 4510426bb..05f7acf9e 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..f16b609f8 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,6 +5,9 @@ package settings import ( "context" "fmt" + "time" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start)) + }() + if userID != activity.SystemInitiator { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 4201b68f6..d83d160c3 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -46,6 +49,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 1 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -55,6 +63,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -307,6 +316,10 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("SavePeer: took %s", time.Since(start)) + }() // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -778,6 +791,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -788,9 +808,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -800,70 +830,1147 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil - account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + account.InitOnce() + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil + return account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups sql.NullString + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange sql.NullString + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups sql.NullString + networkNet sql.NullString + dnsSettingsDisabledGroups sql.NullString + networkIdentifier sql.NullString + networkDns sql.NullString + networkSerial sql.NullInt64 + createdAt sql.NullTime + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + &sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(err) + } + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if networkNet.Valid { + _ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net) + } + if createdAt.Valid { + account.CreatedAt = createdAt.Time + } + if dnsSettingsDisabledGroups.Valid { + _ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups) + } + if networkIdentifier.Valid { + account.Network.Identifier = networkIdentifier.String + } + if networkDns.Valid { + account.Network.Dns = networkDns.String + } + if networkSerial.Valid { + account.Network.Serial = uint64(networkSerial.Int64) + } + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups.Valid { + _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) + } + if sNetworkRange.Valid { + _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups.Valid { + _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) + } + account.InitOnce() return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, + revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt, + &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = &expiresAt.Time + } + if skCreatedAt.Valid { + sk.CreatedAt = skCreatedAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, + meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ( + lastLogin, createdAt sql.NullTime + sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + peerStatusLastSeen sql.NullTime + peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + ip, extraDNS, netAddr, env, flags, files, connIP []byte + metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString + metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString + metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString + locationCountryCode, locationCityName sql.NullString + locationGeoNameID sql.NullInt64 + ) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, + &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, + &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, + &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, + &locationCountryCode, &locationCityName, &locationGeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + p.CreatedAt = createdAt.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if metaHostname.Valid { + p.Meta.Hostname = metaHostname.String + } + if metaGoOS.Valid { + p.Meta.GoOS = metaGoOS.String + } + if metaKernel.Valid { + p.Meta.Kernel = metaKernel.String + } + if metaCore.Valid { + p.Meta.Core = metaCore.String + } + if metaPlatform.Valid { + p.Meta.Platform = metaPlatform.String + } + if metaOS.Valid { + p.Meta.OS = metaOS.String + } + if metaOSVersion.Valid { + p.Meta.OSVersion = metaOSVersion.String + } + if metaWtVersion.Valid { + p.Meta.WtVersion = metaWtVersion.String + } + if metaUIVersion.Valid { + p.Meta.UIVersion = metaUIVersion.String + } + if metaKernelVersion.Valid { + p.Meta.KernelVersion = metaKernelVersion.String + } + if metaSystemSerialNumber.Valid { + p.Meta.SystemSerialNumber = metaSystemSerialNumber.String + } + if metaSystemProductName.Valid { + p.Meta.SystemProductName = metaSystemProductName.String + } + if metaSystemManufacturer.Valid { + p.Meta.SystemManufacturer = metaSystemManufacturer.String + } + if locationCountryCode.Valid { + p.Location.CountryCode = locationCountryCode.String + } + if locationCityName.Valid { + p.Location.CityName = locationCityName.String + } + if locationGeoNameID.Valid { + p.Location.GeoNameID = uint(locationGeoNameID.Int64) + } + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin, createdAt sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + u.CreatedAt = createdAt.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + var createdAt, updatedAt sql.NullTime + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &createdAt, + &updatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if createdAt.Valid { + account.Onboarding.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + account.Onboarding.UpdatedAt = updatedAt.Time + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed, createdAt sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if createdAt.Valid { + pat.CreatedAt = createdAt.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -1054,6 +2161,10 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetAccountNetwork: took %s", time.Since(start)) + }() ctx, cancel := getDebuggingCtx(ctx) defer cancel() @@ -1095,6 +2206,11 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("getAccountSettings: took %s", time.Since(start)) + }() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -1203,8 +2319,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectToPgDb(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil } // NewMysqlStore creates a new MySQL store. @@ -1273,7 +2422,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics, false) + store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false) if err != nil { return nil, err } @@ -1293,6 +2442,50 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } +// used for tests only +func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { + db, err := gorm.Open(postgres.Open(dsn), getGormConfig()) + if err != nil { + return nil, err + } + pool, err := connectToPgDbForTests(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} + +// used for tests only +func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 5 + config.MinConns = 1 + config.MaxConnLifetime = 30 * time.Second + config.HealthCheckPeriod = 10 * time.Second + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil +} + // NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewMysqlStore(ctx, dsn, metrics, false) diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go new file mode 100644 index 000000000..8ff04d68a --- /dev/null +++ b/management/server/store/sql_store_get_account_test.go @@ -0,0 +1,1089 @@ +package store + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/integration_reference" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads +// all fields and nested objects from the database, including deeply nested structures. +func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: []types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t *testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, "Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== + t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") + assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, "User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] + require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, "Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, "Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..350a1da83 --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,951 @@ +package store + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload(clause.Associations). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). + Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 12 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { + cleanup, dsn, err := testutil.CreatePostgresTestContainer() + if err != nil { + b.Fatalf("failed to create test container: %v", err) + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDBforTest(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + err := db.Migrator().DropTable(models[i]) + if err != nil { + b.Fatalf("failed to drop table: %v", err) + } + } + + err = db.AutoMigrate(models...) + if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, cleanup, accountID +} + +func BenchmarkGetAccount(b *testing.B) { + store, cleanup, accountID := setupBenchmarkDB(b) + defer cleanup() + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("gorm opt", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountGormOpt(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, cleanup, accountID := setupBenchmarkDB(t) + defer cleanup() + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 21b660d96..007e2b739 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) closeConnection := func() { cleanup() store.Close(ctx) + if store.pool != nil { + store.pool.Close() + } } return store, closeConnection, nil @@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + if err != nil { return nil, nil, err } @@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) } - db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err) } + sqlDB, err := db.DB() + if err != nil { + return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err) + } + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(1) + dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB.Close() + if err != nil { return nil, nil, err } @@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return store, cleanup, nil } +func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) { + var db *gorm.DB + var err error + + for i := range maxRetries { + switch engine { + case types.PostgresStoreEngine: + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case types.MysqlStoreEngine: + db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + } + + if err == nil { + return db, nil + } + + if i < maxRetries-1 { + waitTime := time.Duration(100*(i+1)) * time.Millisecond + time.Sleep(waitTime) + } + } + + return nil, err +} + func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) @@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func( return "", nil, fmt.Errorf("failed to create database: %v", err) } - var err error + originalDSN := dsn + cleanup := func() { + var dropDB *gorm.DB + var err error + switch engine { case types.PostgresStoreEngine: - err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error + case types.MysqlStoreEngine: - // err = killMySQLConnections(dsn, dbName) - err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error } + if err != nil { log.Errorf("failed to drop database %s: %v", dbName, err) - panic(err) } - sqlDB, _ := db.DB() - _ = sqlDB.Close() } return replaceDBName(dsn, dbName), cleanup, nil diff --git a/management/server/types/account.go b/management/server/types/account.go index 50bdc6ab3..dd6052498 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,6 +8,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/hashicorp/go-multierror" @@ -87,6 +88,13 @@ type Account struct { NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` + + NetworkMapCache *NetworkMapBuilder `gorm:"-"` + nmapInitOnce *sync.Once `gorm:"-"` +} + +func (a *Account) InitOnce() { + a.nmapInitOnce = &sync.Once{} } // this class is used by gorm only @@ -257,6 +265,9 @@ func (a *Account) GetPeerNetworkMap( metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetPeerNetworkMap: took %s", time.Since(start)) + }() peer := a.Peers[peerID] if peer == nil { @@ -890,6 +901,8 @@ func (a *Account) Copy() *Account { NetworkRouters: networkRouters, NetworkResources: networkResources, Onboarding: a.Onboarding, + NetworkMapCache: a.NetworkMapCache, + nmapInitOnce: a.nmapInitOnce, } } diff --git a/management/server/types/holder.go b/management/server/types/holder.go new file mode 100644 index 000000000..3996db2b6 --- /dev/null +++ b/management/server/types/holder.go @@ -0,0 +1,43 @@ +package types + +import ( + "context" + "sync" +) + +type Holder struct { + mu sync.RWMutex + accounts map[string]*Account +} + +func NewHolder() *Holder { + return &Holder{ + accounts: make(map[string]*Account), + } +} + +func (h *Holder) GetAccount(id string) *Account { + h.mu.RLock() + defer h.mu.RUnlock() + return h.accounts[id] +} + +func (h *Holder) AddAccount(account *Account) { + h.mu.Lock() + defer h.mu.Unlock() + h.accounts[account.Id] = account +} + +func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { + h.mu.Lock() + defer h.mu.Unlock() + if acc, ok := h.accounts[id]; ok { + return acc, nil + } + account, err := accGetter(context.Background(), id) + if err != nil { + return nil, err + } + h.accounts[id] = account + return account, nil +} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go new file mode 100644 index 000000000..c1099726f --- /dev/null +++ b/management/server/types/networkmap.go @@ -0,0 +1,58 @@ +package types + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { + if a.NetworkMapCache != nil { + return + } + a.nmapInitOnce.Do(func() { + a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) + }) +} + +func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} + +func (a *Account) GetPeerNetworkMapExp( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + a.initNetworkMapBuilder(validatedPeers) + return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics) +} + +func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerAddedIncremental(peerId) +} + +func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerDeleted(peerId) +} + +func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { + if a.NetworkMapCache == nil { + return + } + a.NetworkMapCache.UpdatePeer(peer) +} + +func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go new file mode 100644 index 000000000..d85aaabb2 --- /dev/null +++ b/management/server/types/networkmap_golden_test.go @@ -0,0 +1,1069 @@ +package types_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "slices" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// update flag is used to update the golden file. +// example: go test ./... -v -update +// var update = flag.Bool("update", false, "update golden files") + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. + expiredPeerID = "peer-98" // This peer will be online but with an expired session. + offlinePeerID = "peer-99" // This peer will be completely offline. + routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. + testAccountID = "account-golden-test" +) + +func TestGetPeerNetworkMap_Golden(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file") +} + +func BenchmarkGetPeerNetworkMap(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + b.ResetTimer() + b.Run("old builder", func(b *testing.B) { + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + b.ResetTimer() + b.Run("new builder", func(b *testing.B) { + for range b.N { + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + for _, peerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") + + t.Log("Update golden file with new peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newPeerID) + require.NoError(t, err, "error adding peer to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") + t.Log("Update golden file with OnPeerAdded...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newPeerID := "peer-new-101" + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: net.IP{100, 64, 1, 1}, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + } + + account.Peers[newPeerID] = newPeer + account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) + account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) + validatedPeersMap[newPeerID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") + + t.Log("Update golden file with new router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newRouterID) + require.NoError(t, err, "error adding router to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json") + + t.Log("Update golden file with OnPeerAdded router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newRouterID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedPeerID := "peer-25" // peer from devs group + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedPeerID := "peer-25" // devs group peer + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedPeerID) + require.NoError(t, err, "error deleting peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") + t.Log("Update golden file with OnPeerDeleted...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedRouterID) + require.NoError(t, err, "error deleting routing peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") + + t.Log("Update golden file with deleted router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err) + + require.JSONEq(t, string(expectedJSON), string(jsonData), + "network map after deleting router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + deletedPeerID := "peer-25" + + delete(account.Peers, deletedPeerID) + account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + delete(validatedPeersMap, deletedPeerID) + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + b.ResetTimer() + b.Run("old builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerDeleted(deletedPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { + for _, peer := range networkMap.Peers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + for _, peer := range networkMap.OfflinePeers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + + sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) + sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) + sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) + + sort.Slice(networkMap.FirewallRules, func(i, j int) bool { + r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] + if r1.PeerIP != r2.PeerIP { + return r1.PeerIP < r2.PeerIP + } + if r1.Protocol != r2.Protocol { + return r1.Protocol < r2.Protocol + } + if r1.Direction != r2.Direction { + return r1.Direction < r2.Direction + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + return r1.Port < r2.Port + }) + + sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { + r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] + if r1.RouteID != r2.RouteID { + return r1.RouteID < r2.RouteID + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + if r1.Destination != r2.Destination { + return r1.Destination < r2.Destination + } + if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { + if r1.SourceRanges[0] != r2.SourceRanges[0] { + return r1.SourceRanges[0] < r2.SourceRanges[0] + } + } + return r1.Port < r2.Port + }) + + for _, ranges := range networkMap.RoutesFirewallRules { + sort.Slice(ranges.SourceRanges, func(i, j int) bool { + return ranges.SourceRanges[i] < ranges.SourceRanges[j] + }) + } +} + +func createTestAccountWithEntities() *types.Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + + } + + groups := map[string]*types.Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*types.Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*types.PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &types.Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &types.Network{ + Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*dns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go new file mode 100644 index 000000000..58f1bfa30 --- /dev/null +++ b/management/server/types/networkmapbuilder.go @@ -0,0 +1,1932 @@ +package types + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +const ( + allPeers = "0.0.0.0" + fw = "fw:" + rfw = "route-fw:" + nr = "network-resource-" +) + +type NetworkMapCache struct { + globalRoutes map[route.ID]*route.Route + globalRules map[string]*FirewallRule //ruleId + globalRouteRules map[string]*RouteFirewallRule //ruleId + globalPeers map[string]*nbpeer.Peer + + groupToPeers map[string][]string + peerToGroups map[string][]string + policyToRules map[string][]*PolicyRule //policyId + groupToPolicies map[string][]*Policy + groupToRoutes map[string][]*route.Route + peerToRoutes map[string][]*route.Route + + peerACLs map[string]*PeerACLView + peerRoutes map[string]*PeerRoutesView + peerDNS map[string]*nbdns.Config + + resourceRouters map[string]map[string]*routerTypes.NetworkRouter + resourcePolicies map[string][]*Policy + + globalResources map[string]*resourceTypes.NetworkResource // resourceId + + acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info + noACGRoutes map[route.ID]*RouteOwnerInfo + + mu sync.RWMutex +} + +type RouteOwnerInfo struct { + PeerID string + RouteID route.ID +} + +type PeerACLView struct { + ConnectedPeerIDs []string + FirewallRuleIDs []string +} + +type PeerRoutesView struct { + OwnRouteIDs []route.ID + NetworkResourceIDs []route.ID + InheritedRouteIDs []route.ID + RouteFirewallRuleIDs []string +} + +type NetworkMapBuilder struct { + account atomic.Pointer[Account] + cache *NetworkMapCache + validatedPeers map[string]struct{} +} + +func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { + builder := &NetworkMapBuilder{ + cache: &NetworkMapCache{ + globalRoutes: make(map[route.ID]*route.Route), + globalRules: make(map[string]*FirewallRule), + globalRouteRules: make(map[string]*RouteFirewallRule), + globalPeers: make(map[string]*nbpeer.Peer), + groupToPeers: make(map[string][]string), + peerToGroups: make(map[string][]string), + policyToRules: make(map[string][]*PolicyRule), + groupToPolicies: make(map[string][]*Policy), + groupToRoutes: make(map[string][]*route.Route), + peerToRoutes: make(map[string][]*route.Route), + peerACLs: make(map[string]*PeerACLView), + peerRoutes: make(map[string]*PeerRoutesView), + peerDNS: make(map[string]*nbdns.Config), + globalResources: make(map[string]*resourceTypes.NetworkResource), + acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), + noACGRoutes: make(map[route.ID]*RouteOwnerInfo), + }, + validatedPeers: make(map[string]struct{}), + } + builder.account.Store(account) + maps.Copy(builder.validatedPeers, validatedPeers) + + builder.initialBuild(account) + + return builder +} + +func (b *NetworkMapBuilder) initialBuild(account *Account) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + start := time.Now() + + b.buildGlobalIndexes(account) + + resourceRouters := account.GetResourceRoutersMap() + resourcePolicies := account.GetResourcePoliciesMap() + b.cache.resourceRouters = resourceRouters + b.cache.resourcePolicies = resourcePolicies + + for peerID := range account.Peers { + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + } + + log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) +} + +func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { + clear(b.cache.globalPeers) + clear(b.cache.groupToPeers) + clear(b.cache.peerToGroups) + clear(b.cache.policyToRules) + clear(b.cache.groupToPolicies) + clear(b.cache.globalRoutes) + clear(b.cache.globalRules) + clear(b.cache.globalRouteRules) + clear(b.cache.globalResources) + clear(b.cache.groupToRoutes) + clear(b.cache.peerToRoutes) + clear(b.cache.acgToRoutes) + clear(b.cache.noACGRoutes) + + maps.Copy(b.cache.globalPeers, account.Peers) + + for groupID, group := range account.Groups { + peersCopy := make([]string, len(group.Peers)) + copy(peersCopy, group.Peers) + b.cache.groupToPeers[groupID] = peersCopy + + for _, peerID := range group.Peers { + b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) + } + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + b.cache.policyToRules[policy.ID] = policy.Rules + + affectedGroups := make(map[string]struct{}) + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, groupID := range rule.Sources { + affectedGroups[groupID] = struct{}{} + } + for _, groupID := range rule.Destinations { + affectedGroups[groupID] = struct{}{} + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + groupId := rule.SourceResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + groupId := rule.DestinationResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) + } + } + + for groupID := range affectedGroups { + b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) + } + } + + for _, resource := range account.NetworkResources { + if !resource.Enabled { + continue + } + b.cache.globalResources[resource.ID] = resource + } + + for _, r := range account.Routes { + if !r.Enabled { + continue + } + for _, groupID := range r.PeerGroups { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } +} + +func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { + peer := account.GetPeer(peerID) + if peer == nil { + return + } + + allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers) + + isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) + + var emptyExpiredPeers []*nbpeer.Peer + finalAllPeers := b.addNetworksRoutingPeers( + networkResourcesRoutes, + peer, + allPotentialPeers, + emptyExpiredPeers, + isRouter, + sourcePeers, + ) + + view := &PeerACLView{ + ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), + FirewallRuleIDs: make([]string, 0, len(firewallRules)), + } + + for _, p := range finalAllPeers { + view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) + } + + for _, rule := range firewallRules { + ruleID := b.generateFirewallRuleID(rule) + view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) + b.cache.globalRules[ruleID] = rule + } + + b.cache.peerACLs[peerID] = view +} + +func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, + validatedPeersMap map[string]struct{}, +) ([]*nbpeer.Peer, []*FirewallRule) { + ctx := context.Background() + + peerID := peer.ID + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + fwRules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + for _, group := range peerGroups { + policies := b.cache.groupToPolicies[group] + for _, policy := range policies { + if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { + continue + } + rules := b.cache.policyToRules[policy.ID] + for _, rule := range rules { + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peerInSources = rule.SourceResource.ID == peerID + } else { + peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peerInDestinations = rule.DestinationResource.ID == peerID + } else { + peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) + } + + if !peerInSources && !peerInDestinations { + continue + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peer := account.GetPeer(rule.SourceResource.ID) + if peer != nil { + sourcePeers = []*nbpeer.Peer{peer} + } + } else { + sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peer := account.GetPeer(rule.DestinationResource.ID) + if peer != nil { + destinationPeers = []*nbpeer.Peer{peer} + } + } else { + destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) + } + + if rule.Bidirectional { + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + } + } + + return peers, fwRules +} + +func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { + for _, groupID := range groupIDs { + if _, exists := peerGroupsMap[groupID]; exists { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, + excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, +) []*nbpeer.Peer { + ctx := context.Background() + uniquePeers := make(map[string]*nbpeer.Peer) + + for _, groupID := range groupIDs { + peerIDs := b.cache.groupToPeers[groupID] + for _, peerID := range peerIDs { + if peerID == excludePeerID { + continue + } + + if _, ok := validatedPeersMap[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + if len(postureChecksIDs) > 0 { + if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { + continue + } + } + + uniquePeers[peerID] = peer + } + } + + result := make([]*nbpeer.Peer, 0, len(uniquePeers)) + for _, peer := range uniquePeers { + result = append(result, peer) + } + + return result +} + +func (b *NetworkMapBuilder) generateResourcescached( + account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, + peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, +) { + isAll := false + if allGroup, err := account.GetGroupAll(); err == nil { + isAll = (len(allGroup.Peers) - 1) == len(groupPeers) + } + + for _, peer := range groupPeers { + if peer == nil { + continue + } + if _, ok := peersExists[peer.ID]; !ok { + *peers = append(*peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = allPeers + } + + var s strings.Builder + s.WriteString(rule.ID) + s.WriteString(fr.PeerIP) + s.WriteString(strconv.Itoa(direction)) + s.WriteString(fr.Protocol) + s.WriteString(fr.Action) + s.WriteString(strings.Join(rule.Ports, ",")) + + ruleID := s.String() + + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + *rules = append(*rules, &fr) + continue + } + + *rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...) + } +} + +func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) { + ctx := context.Background() + peerID := peer.ID + + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, resource := range b.cache.globalResources { + + networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID] + resourcePolicies := b.cache.resourcePolicies[resource.ID] + if len(resourcePolicies) == 0 { + continue + } + + isRouterForThisResource := false + + if networkRoutingPeers != nil { + if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled { + isRoutingPeer = true + isRouterForThisResource = true + if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + + hasAccessAsClient := false + if !isRouterForThisResource { + for _, policy := range resourcePolicies { + if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + hasAccessAsClient = true + break + } + } + } + } + + if hasAccessAsClient && networkRoutingPeers != nil { + for routerPeerID, router := range networkRoutingPeers { + if router.Enabled { + if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil { + routes = append(routes, rt) + } + } + } + } + + if isRouterForThisResource { + for _, policy := range resourcePolicies { + var peersWithAccess []*nbpeer.Peer + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peersWithAccess = []*nbpeer.Peer{peer} + } else { + peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers) + } + for _, p := range peersWithAccess { + allSourcePeers[p.ID] = struct{}{} + } + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (b *NetworkMapBuilder) createNetworkResourceRoutes( + resource *resourceTypes.NetworkResource, routerPeerID string, + router *routerTypes.NetworkRouter, resourcePolicies []*Policy, +) *route.Route { + if len(resourcePolicies) > 0 { + peer := b.cache.globalPeers[routerPeerID] + if peer != nil { + return resource.ToRoute(peer, router) + } + } + return nil +} + +func (b *NetworkMapBuilder) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peer.ID) + delete(networkRoutesPeers, peer.ID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) + } + } + + return peersToConnect +} + +func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { + ctx := context.Background() + peer := account.GetPeer(peerID) + if peer == nil { + return + } + resourcePolicies := b.cache.resourcePolicies + + view := &PeerRoutesView{ + OwnRouteIDs: make([]route.ID, 0), + NetworkResourceIDs: make([]route.ID, 0), + RouteFirewallRuleIDs: make([]string, 0), + } + + enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) + for _, rt := range enabledRoutes { + if rt.PeerID != "" && rt.PeerID != peerID { + if b.cache.globalPeers[rt.PeerID] == nil { + continue + } + } + + view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + aclView := b.cache.peerACLs[peerID] + if aclView != nil { + peerRoutesMembership := make(LookupMap) + for _, r := range append(enabledRoutes, disabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(LookupMap) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, aclPeerID := range aclView.ConnectedPeerIDs { + if aclPeerID == peerID { + continue + } + activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) + groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) + haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + + for _, inheritedRoute := range haFilteredRoutes { + view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) + b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute + } + } + } + + _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) + + for _, rt := range networkResourcesRoutes { + view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) + b.updateACGIndexForPeer(peerID, allRoutes) + + routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) + for _, rule := range routeFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + + if len(networkResourcesRoutes) > 0 { + networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) + for _, rule := range networkResourceFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + } + + b.cache.peerRoutes[peerID] = view +} + +func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == peerID { + delete(b.cache.noACGRoutes, routeID) + } + } + + for _, rt := range routes { + if !rt.Enabled { + continue + } + + if len(rt.AccessControlGroups) == 0 { + b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } else { + for _, acg := range rt.AccessControlGroups { + if b.cache.acgToRoutes[acg] == nil { + b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) + } + + b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } + } + } +} + +func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peer := b.cache.globalPeers[peerID] + if peer == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + // maybe here is some mess - here we store peer key (see comment below) + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + peerGroups := b.cache.peerToGroups[peerID] + for _, groupID := range peerGroups { + groupRoutes := b.cache.groupToRoutes[groupID] + for _, r := range groupRoutes { + newPeerRoute := r.Copy() + // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes + newPeerRoute.Peer = peerID + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) + takeRoute(newPeerRoute, peerID) + } + } + for _, r := range b.cache.peerToRoutes[peerID] { + takeRoute(r.Copy(), peerID) + } + return enabledRoutes, disabledRoutes +} + +func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) + for _, route := range enabledRoutes { + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := b.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) + + rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make(map[string]*Policy) + + for _, groupID := range accessControlGroups { + candidatePolicies := b.cache.groupToPolicies[groupID] + + for _, policy := range candidatePolicies { + if _, found := routePolicies[policy.ID]; found { + continue + } + policyRules := b.cache.policyToRules[policy.ID] + for _, rule := range policyRules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies[policy.ID] = policy + break + } + } + } + } + + return maps.Values(routePolicies) +} + +func (b *NetworkMapBuilder) getRouteFirewallRules( + peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, + distributionPeers map[string]struct{}, account *Account, +) []*RouteFirewallRule { + ctx := context.Background() + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) + + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (b *NetworkMapBuilder) getRulePeers( + rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, + validatedPeersMap map[string]struct{}, account *Account, +) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + + for _, id := range rule.Sources { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := b.cache.globalPeers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { + peerGroups := b.cache.peerToGroups[peerID] + checkGroups := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + checkGroups[groupID] = struct{}{} + } + + dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) + dnsConfig := &nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) + } + + b.cache.peerDNS[peerID] = dnsConfig +} + +func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { + + enabled := true + for _, groupID := range account.DNSSettings.DisabledManagementGroups { + _, found := checkGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := checkGroups[gID] + if found { + peer := b.cache.globalPeers[peerID] + if !peerIsNameserver(peer, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { + b.account.Store(account) +} + +func (b *NetworkMapBuilder) GetPeerNetworkMap( + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + account := b.account.Load() + + peer := account.GetPeer(peerID) + if peer == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + aclView := b.cache.peerACLs[peerID] + routesView := b.cache.peerRoutes[peerID] + dnsConfig := b.cache.peerDNS[peerID] + + if aclView == nil || routesView == nil || dnsConfig == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", + account.Id, objectCount) + } + } + + return nm +} + +func (b *NetworkMapBuilder) assembleNetworkMap( + account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, + dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, +) *NetworkMap { + + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, peerID := range aclView.ConnectedPeerIDs { + if _, ok := validatedPeers[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if account.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, peer) + } else { + peersToConnect = append(peersToConnect, peer) + } + } + + var routes []*route.Route + allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) + + for _, routeID := range allRouteIDs { + if route := b.cache.globalRoutes[routeID]; route != nil { + routes = append(routes, route) + } + } + + var firewallRules []*FirewallRule + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil { + firewallRules = append(firewallRules, rule) + } + } + + var routesFirewallRules []*RouteFirewallRule + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + + finalDNSConfig := *dnsConfig + if finalDNSConfig.ServiceEnable && customZone.Domain != "" { + var zones []nbdns.CustomZone + records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: customZone.Domain, + Records: records, + }) + finalDNSConfig.CustomZones = zones + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: account.Network.Copy(), + Routes: routes, + DNSConfig: finalDNSConfig, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, + } +} + +func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { + var s strings.Builder + s.WriteString(fw) + s.WriteString(rule.PolicyID) + s.WriteRune(':') + s.WriteString(rule.PeerIP) + s.WriteRune(':') + s.WriteString(strconv.Itoa(rule.Direction)) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(rule.Port) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) + s.WriteRune('-') + s.WriteString(strconv.Itoa(int(rule.PortRange.End))) + return s.String() +} + +func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { + var s strings.Builder + s.WriteString(rfw) + s.WriteString(string(rule.RouteID)) + s.WriteRune(':') + s.WriteString(rule.Destination) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(strings.Join(rule.SourceRanges, ",")) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.Port))) + return s.String() +} + +func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(peerGroups, groupID) { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { + for _, r := range account.Routes { + if !r.Enabled { + continue + } + + if r.PeerID == peerID { + return true + } + + if peer := b.cache.globalPeers[peerID]; peer != nil { + if r.Peer == peer.Key && r.PeerID == "" { + return true + } + } + } + + routers := account.GetResourceRoutersMap() + for _, networkRouters := range routers { + if router, exists := networkRouters[peerID]; exists && router.Enabled { + return true + } + } + + return false +} + +type ViewDelta struct { + AddedPeerIDs []string + RemovedPeerIDs []string + AddedRuleIDs []string + RemovedRuleIDs []string +} + +func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error { + tt := time.Now() + account := b.account.Load() + peer := account.GetPeer(peerID) + if peer == nil { + return fmt.Errorf("peer %s not found in account", peerID) + } + + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) + + b.validatedPeers[peerID] = struct{}{} + + b.cache.globalPeers[peerID] = peer + + peerGroups := b.updateIndexesForNewPeer(account, peerID) + + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + + log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) + + b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) + + log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) + + return nil +} + +func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { + peerGroups := make([]string, 0) + + for groupID, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { + b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) + } + peerGroups = append(peerGroups, groupID) + } + } + + b.cache.peerToGroups[peerID] = peerGroups + + for _, r := range account.Routes { + if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { + continue + } + for _, groupID := range r.PeerGroups { + if !slices.Contains(b.cache.groupToRoutes[groupID], r) { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } + b.cache.globalRoutes[r.ID] = r + } + + return peerGroups +} + +func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { + updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) + + if b.isPeerRouter(account, newPeerID) { + affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) + for affectedPeerID := range affectedByRoutes { + if affectedPeerID == newPeerID { + continue + } + if _, exists := updates[affectedPeerID]; !exists { + updates[affectedPeerID] = &PeerUpdateDelta{ + PeerID: affectedPeerID, + RebuildRoutesView: true, + } + } else { + updates[affectedPeerID].RebuildRoutesView = true + } + } + } + + for affectedPeerID, delta := range updates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } +} + +func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { + affected := make(map[string]struct{}) + enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) + + for _, route := range enabledRoutes { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + + for _, peerGroupID := range route.PeerGroups { + if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + } + + for _, route := range account.Routes { + if !route.Enabled { + continue + } + + routerInPeerGroups := false + for _, peerGroupID := range route.PeerGroups { + if slices.Contains(routerGroups, peerGroupID) { + routerInPeerGroups = true + break + } + } + + if routerInPeerGroups { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + affected[peerID] = struct{}{} + } + } + } + } + } + + return affected +} + +func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { + updates := make(map[string]*PeerUpdateDelta) + ctx := context.Background() + + groupAllLn := 0 + if allGroup, err := account.GetGroupAll(); err == nil { + groupAllLn = len(allGroup.Peers) - 1 + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return updates + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + peerInSources := b.isPeerInGroups(rule.Sources, peerGroups) + peerInDestinations := b.isPeerInGroups(rule.Destinations, peerGroups) + + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + + if rule.Bidirectional { + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + } + } + } + + b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) + + b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) + + b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) + + return updates +} + +func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( + ctx context.Context, account *Account, newPeerID string, + updates map[string]*PeerUpdateDelta, +) { + resourceRouters := b.cache.resourceRouters + + for networkID, routers := range resourceRouters { + router, isRouter := routers[newPeerID] + if !isRouter || !router.Enabled { + continue + } + + for _, resource := range b.cache.globalResources { + if resource.NetworkID != networkID { + continue + } + + policies := b.cache.resourcePolicies[resource.ID] + if len(policies) == 0 { + continue + } + + peersWithAccess := make(map[string]struct{}) + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + groupPeers := b.cache.groupToPeers[sourceGroup] + for _, peerID := range groupPeers { + if peerID == newPeerID { + continue + } + + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + peersWithAccess[peerID] = struct{}{} + } + } + } + } + + for peerID := range peersWithAccess { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + } + updates[peerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } + } +} + +func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( + newPeerID string, newPeer *nbpeer.Peer, + peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + processedPeerRoutes := make(map[string]map[route.ID]struct{}) + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == newPeerID { + continue + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + + for _, acg := range peerGroups { + routeInfos := b.cache.acgToRoutes[acg] + if routeInfos == nil { + continue + } + + for routeID, info := range routeInfos { + if info.PeerID == newPeerID { + continue + } + + if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { + if _, processed := processedRoutes[routeID]; processed { + continue + } + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + } +} + +func (b *NetworkMapBuilder) addRouteFirewallUpdate( + updates map[string]*PeerUpdateDelta, peerID string, + routeID string, sourceIP string, +) { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), + } + updates[peerID] = delta + } + + for _, existing := range delta.UpdateRouteFirewallRules { + if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { + return + } + } + + delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ + RuleID: routeID, + AddSourceIP: sourceIP, + }) +} + +func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( + ctx context.Context, account *Account, newPeerID string, + newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + for _, resource := range b.cache.globalResources { + resourcePolicies := b.cache.resourcePolicies + resourceRouters := b.cache.resourceRouters + + policies := resourcePolicies[resource.ID] + peerHasAccess := false + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + if slices.Contains(peerGroups, sourceGroup) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { + peerHasAccess = true + break + } + } + } + + if peerHasAccess { + break + } + } + + if !peerHasAccess { + continue + } + + networkRouters := resourceRouters[resource.NetworkID] + for routerPeerID, router := range networkRouters { + if !router.Enabled || routerPeerID == newPeerID { + continue + } + + delta := updates[routerPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: routerPeerID, + } + updates[routerPeerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } +} + +type PeerUpdateDelta struct { + PeerID string + AddConnectedPeer string + AddFirewallRules []*FirewallRuleDelta + AddRoutes []route.ID + UpdateRouteFirewallRules []*RouteFirewallRuleUpdate + UpdateDNS bool + RebuildRoutesView bool +} +type FirewallRuleDelta struct { + Rule *FirewallRule + RuleID string + Direction int +} + +type RouteFirewallRuleUpdate struct { + RuleID string + AddSourceIP string +} + +func (b *NetworkMapBuilder) addUpdateForPeersInGroups( + updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, + rule *PolicyRule, direction int, allGroupLn int, +) { + for _, groupID := range groupIDs { + peers := b.cache.groupToPeers[groupID] + cnt := 0 + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + cnt++ + } + all := false + if allGroupLn > 0 && cnt == allGroupLn { + all = true + } + newPeer := b.cache.globalPeers[newPeerID] + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 0), + } + updates[peerID] = delta + } + + if all { + fr.PeerIP = allPeers + } + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*fr, rule, b.cache.globalPeers[peerID]) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(fr) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: fr, + RuleID: ruleID, + Direction: direction, + }) + } + } + } +} + +func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { + if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) { + aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) + } + + for _, ruleDelta := range delta.AddFirewallRules { + b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule + + if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { + aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) + } + } + } + } + + if delta.RebuildRoutesView { + b.buildPeerRoutesView(account, peerID) + } else if len(delta.UpdateRouteFirewallRules) > 0 { + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) + } + } + + if delta.UpdateDNS { + b.buildPeerDNSView(account, peerID) + } +} + +func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { + for _, update := range updates { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + rule := b.cache.globalRouteRules[ruleID] + if rule == nil { + continue + } + + if string(rule.RouteID) == update.RuleID { + sourceIP := update.AddSourceIP + + if strings.Contains(sourceIP, ":") { + sourceIP += "/128" // IPv6 + } else { + sourceIP += "/32" // IPv4 + } + + if !slices.Contains(rule.SourceRanges, sourceIP) { + rule.SourceRanges = append(rule.SourceRanges, sourceIP) + } + break + } + } + } +} + +func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + account := b.account.Load() + + deletedPeer := b.cache.globalPeers[peerID] + if deletedPeer == nil { + return fmt.Errorf("peer %s not found in cache", peerID) + } + + deletedPeerKey := deletedPeer.Key + peerGroups := b.cache.peerToGroups[peerID] + peerIP := deletedPeer.IP.String() + + log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) + + delete(b.validatedPeers, peerID) + + routesToDelete := []route.ID{} + + for routeID, r := range account.Routes { + if r.Peer != deletedPeerKey && r.PeerID != peerID { + continue + } + if len(r.PeerGroups) == 0 { + routesToDelete = append(routesToDelete, routeID) + continue + } + newPeerAssigned := false + for _, groupID := range r.PeerGroups { + candidatePeerIDs := b.cache.groupToPeers[groupID] + for _, candidatePeerID := range candidatePeerIDs { + if candidatePeerID == peerID { + continue + } + if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { + r.Peer = candidatePeer.Key + r.PeerID = candidatePeerID + newPeerAssigned = true + break + } + } + if newPeerAssigned { + break + } + } + + if !newPeerAssigned { + routesToDelete = append(routesToDelete, routeID) + } + } + + for _, routeID := range routesToDelete { + delete(account.Routes, routeID) + } + + delete(b.cache.peerACLs, peerID) + delete(b.cache.peerRoutes, peerID) + delete(b.cache.peerDNS, peerID) + + delete(b.cache.globalPeers, peerID) + + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for _, groupID := range peerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { + return id == peerID + }) + } + } + delete(b.cache.peerToGroups, peerID) + + affectedPeers := make(map[string]struct{}) + + for _, r := range account.Routes { + for _, groupID := range r.Groups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + + for _, groupID := range r.PeerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + } + + for affectedPeerID := range affectedPeers { + if affectedPeerID == peerID { + continue + } + b.buildPeerRoutesView(account, affectedPeerID) + } + + peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP) + for affectedPeerID, updates := range peerDeletionUpdates { + b.applyDeletionUpdates(affectedPeerID, updates) + } + + b.cleanupUnusedRules() + + log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) + + return nil +} + +func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( + deletedPeerID string, + peerIP string, +) map[string]*PeerDeletionUpdate { + + affected := make(map[string]*PeerDeletionUpdate) + + for peerID, aclView := range b.cache.peerACLs { + if peerID == deletedPeerID { + continue + } + + if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { + continue + } + if affected[peerID] == nil { + affected[peerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + } + } + + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP { + affected[peerID].RemoveFirewallRuleIDs = append( + affected[peerID].RemoveFirewallRuleIDs, + ruleID, + ) + } + } + } + + return affected +} + +type PeerDeletionUpdate struct { + RemovePeerID string + RemoveFirewallRuleIDs []string + RemoveRouteIDs []route.ID + RemoveFromSourceRanges bool + PeerIP string +} + +func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool { + return id == updates.RemovePeerID + }) + + if len(updates.RemoveFirewallRuleIDs) > 0 { + aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID) + }) + } + } + + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + if len(updates.RemoveRouteIDs) > 0 { + routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { + return slices.Contains(updates.RemoveRouteIDs, routeID) + }) + } + + if updates.RemoveFromSourceRanges { + b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) + } + } +} + +func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { + sourceIPv4 := peerIP + "/32" + sourceIPv6 := peerIP + "/128" + + rulesToRemove := []string{} + + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { + return source == sourceIPv4 || source == sourceIPv6 || source == peerIP + }) + + if len(rule.SourceRanges) == 0 { + rulesToRemove = append(rulesToRemove, ruleID) + } + } + } + + if len(rulesToRemove) > 0 { + routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(rulesToRemove, ruleID) + }) + } +} + +func (b *NetworkMapBuilder) cleanupUnusedRules() { + usedFirewallRules := make(map[string]struct{}) + usedRouteRules := make(map[string]struct{}) + usedRoutes := make(map[route.ID]struct{}) + + for _, aclView := range b.cache.peerACLs { + for _, ruleID := range aclView.FirewallRuleIDs { + usedFirewallRules[ruleID] = struct{}{} + } + } + + for _, routesView := range b.cache.peerRoutes { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + usedRouteRules[ruleID] = struct{}{} + } + + for _, routeID := range routesView.OwnRouteIDs { + usedRoutes[routeID] = struct{}{} + } + for _, routeID := range routesView.NetworkResourceIDs { + usedRoutes[routeID] = struct{}{} + } + } + + for ruleID := range b.cache.globalRules { + if _, used := usedFirewallRules[ruleID]; !used { + delete(b.cache.globalRules, ruleID) + } + } + + for ruleID := range b.cache.globalRouteRules { + if _, used := usedRouteRules[ruleID]; !used { + delete(b.cache.globalRouteRules, ruleID) + } + } + + for routeID := range b.cache.globalRoutes { + if _, used := usedRoutes[routeID]; !used { + delete(b.cache.globalRoutes, routeID) + } + } +} + +func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + peerStored, ok := b.cache.globalPeers[peer.ID] + if !ok { + return + } + *peerStored = *peer +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index da12f1b70..adf64592a 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -7,16 +7,14 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse - NetworkMap *types.NetworkMap + Update *proto.SyncResponse } type PeersUpdateManager struct { diff --git a/management/server/user.go b/management/server/user.go index 25c87df9c..66bea314f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -991,6 +991,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } } if len(peerIDs) != 0 { diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: r.Description, NetID: r.NetID, Network: r.Network,