mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
43 Commits
v0.60.8
...
refactor/n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
780890f9e6 | ||
|
|
81429f104c | ||
|
|
479c75f827 | ||
|
|
60e3bf4084 | ||
|
|
1002d45232 | ||
|
|
84decc36c0 | ||
|
|
8875168d2b | ||
|
|
f193f0fa9d | ||
|
|
e0fed79690 | ||
|
|
a00e7654b7 | ||
|
|
43b2d599b8 | ||
|
|
0fdf8138f2 | ||
|
|
5098410e66 | ||
|
|
9eda1ade4a | ||
|
|
274711a37e | ||
|
|
53e24ae7f7 | ||
|
|
fbc02343e9 | ||
|
|
ffed4b38ef | ||
|
|
e926ca34b5 | ||
|
|
5d1c61369d | ||
|
|
fd9e21a5f3 | ||
|
|
841bc7564a | ||
|
|
f20a1b3328 | ||
|
|
2ac0da6cac | ||
|
|
148b8b04b3 | ||
|
|
9a56883ffb | ||
|
|
806be13dd5 | ||
|
|
90557da237 | ||
|
|
1d6209841e | ||
|
|
8f0e5708d5 | ||
|
|
5a9aa55121 | ||
|
|
6082c7cdcb | ||
|
|
06eae13352 | ||
|
|
08fba9876b | ||
|
|
ca85aa9b8f | ||
|
|
0ae2241573 | ||
|
|
050c05164a | ||
|
|
333908d06e | ||
|
|
bc6c5ece6e | ||
|
|
fd7b3ae21c | ||
|
|
abd7a84a46 | ||
|
|
f4b2bed1b9 | ||
|
|
2fb971e88a |
@@ -1,11 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -53,6 +53,9 @@ const (
|
||||
peerSchedulerRetryInterval = 3 * time.Second
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
|
||||
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@@ -109,6 +112,11 @@ type DefaultAccountManager struct {
|
||||
loginFilter *loginFilter
|
||||
|
||||
disableDefaultPolicy bool
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -196,6 +204,18 @@ func BuildManager(
|
||||
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
|
||||
}()
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
geo: geo,
|
||||
@@ -217,6 +237,10 @@ func BuildManager(
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
holder: types.NewHolder(),
|
||||
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
@@ -2129,6 +2160,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -128,4 +128,5 @@ type Manager interface {
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
@@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1736,6 +1781,7 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
|
||||
@@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -471,6 +483,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -509,6 +524,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -537,6 +555,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -575,6 +596,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -44,6 +46,9 @@ import (
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -63,6 +68,9 @@ type GRPCServer struct {
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -96,6 +104,16 @@ func NewServer(
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -110,6 +128,8 @@ func NewServer(
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -151,6 +171,11 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
@@ -158,6 +183,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
@@ -172,6 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
@@ -196,8 +223,10 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -213,12 +242,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -237,6 +268,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
|
||||
39
management/server/holder.go
Normal file
39
management/server/holder.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
|
||||
a := am.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
am.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache.UpdateAccountPointer(account)
|
||||
am.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
|
||||
return am.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
|
||||
a := am.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
|
||||
am.holder.AddAccount(account)
|
||||
}
|
||||
@@ -125,9 +125,10 @@ type MockAccountManager struct {
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
|
||||
if am.RecalculateNetworkMapCacheFunc != nil {
|
||||
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
80
management/server/networkmap.go
Normal file
80
management/server/networkmap.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
am.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := am.getAccountFromHolderOrInit(accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
am.enrichAccountFromHolder(account)
|
||||
return account.OnPeerAddedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
am.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := am.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
am.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if am.experimentalNetworkMap(accountId) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
am.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := am.expNewNetworkMapAIDs[accountId]
|
||||
return am.expNewNetworkMap || ok
|
||||
}
|
||||
@@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
|
||||
return resource, nil
|
||||
@@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
|
||||
return resource, nil
|
||||
@@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
|
||||
return router, nil
|
||||
@@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
|
||||
return router, nil
|
||||
@@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -145,6 +145,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
|
||||
if expired {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
@@ -321,6 +324,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
|
||||
if peerLabelChanged || requiresPeerUpdates {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
} else if sshChanged {
|
||||
@@ -381,6 +388,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
@@ -417,7 +436,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -690,6 +715,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
@@ -776,6 +812,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -901,6 +940,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -1014,9 +1056,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account = am.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
@@ -1037,7 +1087,13 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -1167,11 +1223,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account = am.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -1204,6 +1267,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -1241,7 +1308,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
@@ -1257,7 +1330,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
@@ -1351,7 +1424,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -1368,7 +1447,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
@@ -1580,7 +1659,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &types.NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
peerDeletedEvents = append(peerDeletedEvents, func() {
|
||||
|
||||
@@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func testGetNetworkMapGeneral(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func testUpdateAccountPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
@@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) {
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
@@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -857,6 +857,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
}
|
||||
account.NameServerGroupsG = nil
|
||||
|
||||
account.InitOnce()
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -21,7 +22,6 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -87,6 +87,13 @@ type Account struct {
|
||||
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
|
||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
|
||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||
nmapInitOnce *sync.Once `gorm:"-"`
|
||||
}
|
||||
|
||||
func (a *Account) InitOnce() {
|
||||
a.nmapInitOnce = &sync.Once{}
|
||||
}
|
||||
|
||||
// this class is used by gorm only
|
||||
@@ -124,109 +131,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
o.SignupFormPending == onboarding.SignupFormPending
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.GetPeerGroups(peerID)
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
for _, groupID := range r.Groups {
|
||||
_, found := groupListMap[groupID]
|
||||
if found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
// If the given is not a routing peer, then the lists are empty.
|
||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||
|
||||
peer := a.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
if _, ok := seenRoute[r.ID]; ok {
|
||||
return
|
||||
}
|
||||
seenRoute[r.ID] = struct{}{}
|
||||
|
||||
if r.Enabled {
|
||||
r.Peer = peer.Key
|
||||
enabledRoutes = append(enabledRoutes, r)
|
||||
return
|
||||
}
|
||||
disabledRoutes = append(disabledRoutes, r)
|
||||
}
|
||||
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
for _, id := range group.Peers {
|
||||
if id != peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
newPeerRoute := r.Copy()
|
||||
newPeerRoute.Peer = id
|
||||
newPeerRoute.PeerGroups = nil
|
||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||
takeRoute(newPeerRoute, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
if r.Peer == peerID {
|
||||
takeRoute(r.Copy(), peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
||||
var routes []*route.Route
|
||||
@@ -246,174 +150,6 @@ func (a *Account) GetGroup(groupID string) *Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
for _, p := range aclPeers {
|
||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||
expiredPeers = append(expiredPeers, p)
|
||||
continue
|
||||
}
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
if isRouter {
|
||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
||||
}
|
||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
||||
|
||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnectIncludingRouters,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
peersToConnect []*nbpeer.Peer,
|
||||
expiredPeers []*nbpeer.Peer,
|
||||
isRouter bool,
|
||||
sourcePeers map[string]struct{},
|
||||
) []*nbpeer.Peer {
|
||||
|
||||
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||
for _, r := range networkResourcesRoutes {
|
||||
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||
}
|
||||
|
||||
delete(sourcePeers, peer.ID)
|
||||
delete(networkRoutesPeers, peer.ID)
|
||||
|
||||
for _, existingPeer := range peersToConnect {
|
||||
delete(sourcePeers, existingPeer.ID)
|
||||
delete(networkRoutesPeers, existingPeer.ID)
|
||||
}
|
||||
for _, expPeer := range expiredPeers {
|
||||
delete(sourcePeers, expPeer.ID)
|
||||
delete(networkRoutesPeers, expPeer.ID)
|
||||
}
|
||||
|
||||
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||
if isRouter {
|
||||
for p := range sourcePeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
for p := range networkRoutesPeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
for p := range missingPeers {
|
||||
if missingPeer := a.Peers[p]; missingPeer != nil {
|
||||
peersToConnect = append(peersToConnect, missingPeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.GetPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := GetPeerHostLabel(peer.Name, peerLabels)
|
||||
@@ -760,32 +496,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
||||
return grps
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
enabled := true
|
||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
||||
_, found := peerGroups[groupID]
|
||||
if found {
|
||||
enabled = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
groupList := make(LookupMap)
|
||||
for groupID, group := range a.Groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peerID {
|
||||
groupList[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return groupList
|
||||
}
|
||||
|
||||
func (a *Account) GetTakenIPs() []net.IP {
|
||||
var takenIps []net.IP
|
||||
for _, existingPeer := range a.Peers {
|
||||
@@ -890,6 +600,7 @@ func (a *Account) Copy() *Account {
|
||||
NetworkRouters: networkRouters,
|
||||
NetworkResources: networkResources,
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -985,192 +696,6 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var sourcePeers, destinationPeers []*nbpeer.Peer
|
||||
var peerInSources, peerInDestinations bool
|
||||
|
||||
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID)
|
||||
} else {
|
||||
sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
}
|
||||
|
||||
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID)
|
||||
} else {
|
||||
destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
|
||||
}
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
}
|
||||
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
//
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := peersExists[peer.ID]; !ok {
|
||||
peers = append(peers, peer)
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
}
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
// and a boolean indicating if the supplied peer ID exists within these groups.
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||
for _, p := range uniquePeerIDs {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
|
||||
func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
|
||||
peer := a.GetPeer(resource.ID)
|
||||
if peer == nil {
|
||||
return []*nbpeer.Peer{}, false
|
||||
}
|
||||
|
||||
return []*nbpeer.Peer{peer}, resource.ID == peerID
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok && peer == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := a.GetPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, check := range postureChecks.GetChecks() {
|
||||
isValid, err := check.Check(ctx, *peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
}
|
||||
if !isValid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range a.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
@@ -1180,174 +705,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
if pID == peerID {
|
||||
continue
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||
for pID := range distPeersWithPolicy {
|
||||
peer := a.Peers[pID]
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
distributionGroupPeers = append(distributionGroupPeers, peer)
|
||||
}
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
RouteID: route.ID,
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
routePolicies := make([]*Policy, 0)
|
||||
for _, groupID := range accessControlGroups {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
|
||||
return groupID == group.ID
|
||||
})
|
||||
if exist {
|
||||
routePolicies = append(routePolicies, policy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routePolicies
|
||||
}
|
||||
|
||||
// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Peer != peer.Key {
|
||||
continue
|
||||
}
|
||||
resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())]
|
||||
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
||||
|
||||
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
||||
for _, rule := range rules {
|
||||
if len(rule.SourceRanges) > 0 {
|
||||
routesFirewallRules = append(routesFirewallRules, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
// getNetworkResourceGroups retrieves all groups associated with the given network resource.
|
||||
func (a *Account) getNetworkResourceGroups(resourceID string) []*Group {
|
||||
var networkResourceGroups []*Group
|
||||
@@ -1377,91 +734,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
||||
return resourcePolicies
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var addSourcePeers bool
|
||||
|
||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
||||
if exists {
|
||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||
isRoutingPeer, addSourcePeers = true, true
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
var peers []string
|
||||
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||
peers = []string{policy.Rules[0].SourceResource.ID}
|
||||
} else {
|
||||
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
}
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
if group.IsGroupAll() || len(groups) == 1 {
|
||||
return group.Peers
|
||||
}
|
||||
|
||||
for _, peerID := range group.Peers {
|
||||
peerIDs[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(peerIDs))
|
||||
for peerID := range peerIDs {
|
||||
ids = append(ids, peerID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
|
||||
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
|
||||
var resources []*resourceTypes.NetworkResource
|
||||
@@ -1527,22 +799,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
||||
|
||||
var routes []*route.Route
|
||||
// distribute the resource routes only if there is policy applied to it
|
||||
if len(resourceAppliedPolicies) > 0 {
|
||||
peer := a.GetPeer(peerId)
|
||||
if peer != nil {
|
||||
routes = append(routes, resource.ToRoute(peer, router))
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
|
||||
routers := make(map[string]map[string]*routerTypes.NetworkRouter)
|
||||
|
||||
@@ -1573,28 +829,6 @@ func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.Net
|
||||
return routers
|
||||
}
|
||||
|
||||
// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies.
|
||||
func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} {
|
||||
sourcePeers := make(map[string]struct{})
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := groups[sourceGroup]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range group.Peers {
|
||||
sourcePeers[peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sourcePeers
|
||||
}
|
||||
|
||||
// AddAllGroup to account object if it doesn't exist
|
||||
func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
if len(a.Groups) == 0 {
|
||||
@@ -1638,70 +872,3 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
var expanded []*FirewallRule
|
||||
|
||||
if len(rule.Ports) > 0 {
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
return expanded
|
||||
}
|
||||
|
||||
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
|
||||
for _, portRange := range rule.PortRanges {
|
||||
fr := base
|
||||
|
||||
if supportPortRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
// Peer doesn't support port ranges, only allow single-port ranges
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
// peerSupportsPortRanges checks if the peer version supports port ranges.
|
||||
func peerSupportsPortRanges(peerVer string) bool {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return true
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, expiredPeer := range expiredPeers {
|
||||
peerIPs[expiredPeer.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
43
management/server/types/holder.go
Normal file
43
management/server/types/holder.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Holder struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*Account
|
||||
}
|
||||
|
||||
func NewHolder() *Holder {
|
||||
return &Holder{
|
||||
accounts: make(map[string]*Account),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Holder) GetAccount(id string) *Account {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.accounts[id]
|
||||
}
|
||||
|
||||
func (h *Holder) AddAccount(account *Account) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(context.Background(), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.accounts[id] = account
|
||||
return account, nil
|
||||
}
|
||||
@@ -9,12 +9,7 @@ import (
|
||||
|
||||
"github.com/c-robinson/iplib"
|
||||
"github.com/rs/xid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -29,40 +24,6 @@ const (
|
||||
AllowedIPsFormat = "%s/32"
|
||||
)
|
||||
|
||||
type NetworkMap struct {
|
||||
Peers []*nbpeer.Peer
|
||||
Network *Network
|
||||
Routes []*route.Route
|
||||
DNSConfig nbdns.Config
|
||||
OfflinePeers []*nbpeer.Peer
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
nm.Peers = mergeUniquePeersByID(nm.Peers, other.Peers)
|
||||
nm.Routes = util.MergeUnique(nm.Routes, other.Routes)
|
||||
nm.OfflinePeers = mergeUniquePeersByID(nm.OfflinePeers, other.OfflinePeers)
|
||||
nm.FirewallRules = util.MergeUnique(nm.FirewallRules, other.FirewallRules)
|
||||
nm.RoutesFirewallRules = util.MergeUnique(nm.RoutesFirewallRules, other.RoutesFirewallRules)
|
||||
nm.ForwardingRules = util.MergeUnique(nm.ForwardingRules, other.ForwardingRules)
|
||||
}
|
||||
|
||||
func mergeUniquePeersByID(peers1, peers2 []*nbpeer.Peer) []*nbpeer.Peer {
|
||||
result := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers1 {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
for _, peer := range peers2 {
|
||||
if _, ok := result[peer.ID]; !ok {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(result)
|
||||
}
|
||||
|
||||
type ForwardingRule struct {
|
||||
RuleProtocol string
|
||||
DestinationPorts RulePortRange
|
||||
|
||||
920
management/server/types/networkmap.go
Normal file
920
management/server/types/networkmap.go
Normal file
@@ -0,0 +1,920 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type NetworkMap struct {
|
||||
Peers []*nbpeer.Peer
|
||||
Network *Network
|
||||
Routes []*route.Route
|
||||
DNSConfig nbdns.Config
|
||||
OfflinePeers []*nbpeer.Peer
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
nm.Peers = mergeUniquePeersByID(nm.Peers, other.Peers)
|
||||
nm.Routes = util.MergeUnique(nm.Routes, other.Routes)
|
||||
nm.OfflinePeers = mergeUniquePeersByID(nm.OfflinePeers, other.OfflinePeers)
|
||||
nm.FirewallRules = util.MergeUnique(nm.FirewallRules, other.FirewallRules)
|
||||
nm.RoutesFirewallRules = util.MergeUnique(nm.RoutesFirewallRules, other.RoutesFirewallRules)
|
||||
nm.ForwardingRules = util.MergeUnique(nm.ForwardingRules, other.ForwardingRules)
|
||||
}
|
||||
|
||||
// TODO optimize
|
||||
func mergeUniquePeersByID(peers1, peers2 []*nbpeer.Peer) []*nbpeer.Peer {
|
||||
result := make(map[string]*nbpeer.Peer)
|
||||
for _, peer := range peers1 {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
for _, peer := range peers2 {
|
||||
if _, ok := result[peer.ID]; !ok {
|
||||
result[peer.ID] = peer
|
||||
}
|
||||
}
|
||||
|
||||
return maps.Values(result)
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
for _, p := range aclPeers {
|
||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||
expiredPeers = append(expiredPeers, p)
|
||||
continue
|
||||
}
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
if isRouter {
|
||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
||||
}
|
||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
||||
|
||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnectIncludingRouters,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
}
|
||||
|
||||
if peerInSources {
|
||||
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
|
||||
}
|
||||
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
//
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := peersExists[peer.ID]; !ok {
|
||||
peers = append(peers, peer)
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||
rules = append(rules, &fr)
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
|
||||
}
|
||||
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
return peers, rules
|
||||
}
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
// and a boolean indicating if the supplied peer ID exists within these groups.
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||
for _, p := range uniquePeerIDs {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok && peer == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := a.GetPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, check := range postureChecks.GetChecks() {
|
||||
isValid, err := check.Check(ctx, *peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
}
|
||||
if !isValid {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
expanded := make([]*FirewallRule, 0, len(rule.Ports)+len(rule.PortRanges))
|
||||
|
||||
if len(rule.Ports) > 0 {
|
||||
for _, port := range rule.Ports {
|
||||
fr := base
|
||||
fr.Port = port
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
return expanded
|
||||
}
|
||||
|
||||
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
|
||||
for _, portRange := range rule.PortRanges {
|
||||
fr := base
|
||||
|
||||
if supportPortRanges {
|
||||
fr.PortRange = portRange
|
||||
} else {
|
||||
// Peer doesn't support port ranges, only allow single-port ranges
|
||||
if portRange.Start != portRange.End {
|
||||
continue
|
||||
}
|
||||
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
|
||||
}
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
return expanded
|
||||
}
|
||||
|
||||
// peerSupportsPortRanges checks if the peer version supports port ranges.
|
||||
func peerSupportsPortRanges(peerVer string) bool {
|
||||
if strings.Contains(peerVer, "dev") {
|
||||
return true
|
||||
}
|
||||
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var addSourcePeers bool
|
||||
|
||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
||||
if exists {
|
||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||
isRoutingPeer, addSourcePeers = true, true
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
||||
|
||||
var routes []*route.Route
|
||||
// distribute the resource routes only if there is policy applied to it
|
||||
if len(resourceAppliedPolicies) > 0 {
|
||||
peer := a.GetPeer(peerId)
|
||||
if peer != nil {
|
||||
routes = append(routes, resource.ToRoute(peer, router))
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
if group.IsGroupAll() || len(groups) == 1 {
|
||||
return group.Peers
|
||||
}
|
||||
|
||||
for _, peerID := range group.Peers {
|
||||
peerIDs[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(peerIDs))
|
||||
for peerID := range peerIDs {
|
||||
ids = append(ids, peerID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
}
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
if pID == peerID {
|
||||
continue
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||
for pID := range distPeersWithPolicy {
|
||||
peer := a.Peers[pID]
|
||||
if peer == nil {
|
||||
continue
|
||||
}
|
||||
distributionGroupPeers = append(distributionGroupPeers, peer)
|
||||
}
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
RouteID: route.ID,
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
routePolicies := make([]*Policy, 0)
|
||||
for _, groupID := range accessControlGroups {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
|
||||
return groupID == group.ID
|
||||
})
|
||||
if exist {
|
||||
routePolicies = append(routePolicies, policy)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routePolicies
|
||||
}
|
||||
|
||||
// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0)
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Peer != peer.Key {
|
||||
continue
|
||||
}
|
||||
resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())]
|
||||
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
||||
|
||||
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
||||
for _, rule := range rules {
|
||||
if len(rule.SourceRanges) > 0 {
|
||||
routesFirewallRules = append(routesFirewallRules, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies.
|
||||
func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} {
|
||||
sourcePeers := make(map[string]struct{})
|
||||
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group := groups[sourceGroup]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, peer := range group.Peers {
|
||||
sourcePeers[peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sourcePeers
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
groupListMap := a.GetPeerGroups(peerID)
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
groupList := make(LookupMap)
|
||||
for groupID, group := range a.Groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peerID {
|
||||
groupList[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return groupList
|
||||
}
|
||||
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
for _, groupID := range r.Groups {
|
||||
_, found := groupListMap[groupID]
|
||||
if found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
// If the given is not a routing peer, then the lists are empty.
|
||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||
|
||||
peer := a.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
// log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
if _, ok := seenRoute[r.ID]; ok {
|
||||
return
|
||||
}
|
||||
seenRoute[r.ID] = struct{}{}
|
||||
|
||||
if r.Enabled {
|
||||
r.Peer = peer.Key
|
||||
enabledRoutes = append(enabledRoutes, r)
|
||||
return
|
||||
}
|
||||
disabledRoutes = append(disabledRoutes, r)
|
||||
}
|
||||
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
for _, id := range group.Peers {
|
||||
if id != peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
newPeerRoute := r.Copy()
|
||||
newPeerRoute.Peer = id
|
||||
newPeerRoute.PeerGroups = nil
|
||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||
takeRoute(newPeerRoute, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
if r.Peer == peerID {
|
||||
takeRoute(r.Copy(), peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
peersToConnect []*nbpeer.Peer,
|
||||
expiredPeers []*nbpeer.Peer,
|
||||
isRouter bool,
|
||||
sourcePeers map[string]struct{},
|
||||
) []*nbpeer.Peer {
|
||||
|
||||
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||
for _, r := range networkResourcesRoutes {
|
||||
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||
}
|
||||
|
||||
delete(sourcePeers, peer.ID)
|
||||
delete(networkRoutesPeers, peer.ID)
|
||||
|
||||
for _, existingPeer := range peersToConnect {
|
||||
delete(sourcePeers, existingPeer.ID)
|
||||
delete(networkRoutesPeers, existingPeer.ID)
|
||||
}
|
||||
for _, expPeer := range expiredPeers {
|
||||
delete(sourcePeers, expPeer.ID)
|
||||
delete(networkRoutesPeers, expPeer.ID)
|
||||
}
|
||||
|
||||
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||
if isRouter {
|
||||
for p := range sourcePeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
for p := range networkRoutesPeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
for p := range missingPeers {
|
||||
if missingPeer := a.Peers[p]; missingPeer != nil {
|
||||
peersToConnect = append(peersToConnect, missingPeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
enabled := true
|
||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
||||
_, found := peerGroups[groupID]
|
||||
if found {
|
||||
enabled = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.GetPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
|
||||
if a.NetworkMapCache != nil {
|
||||
return
|
||||
}
|
||||
a.nmapInitOnce.Do(func() {
|
||||
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerDeleted(peerId)
|
||||
}
|
||||
|
||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, expiredPeer := range expiredPeers {
|
||||
peerIPs[expiredPeer.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
1069
management/server/types/networkmap_golden_test.go
Normal file
1069
management/server/types/networkmap_golden_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1761
management/server/types/networkmapbuilder.go
Normal file
1761
management/server/types/networkmapbuilder.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,16 +7,14 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *types.NetworkMap
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
|
||||
@@ -961,6 +961,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
peer.UserID, peer.ID, accountID,
|
||||
activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
|
||||
)
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
}
|
||||
|
||||
if len(peerIDs) != 0 {
|
||||
|
||||
Reference in New Issue
Block a user