[management] Refactor network map controller (#4789)

This commit is contained in:
Pascal Fischer
2025-12-02 12:34:28 +01:00
committed by GitHub
parent 52948ccd61
commit 7193bd2da7
45 changed files with 819 additions and 492 deletions

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto"
@@ -24,8 +26,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}

View File

@@ -30,11 +30,12 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -54,7 +55,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -1628,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

View File

@@ -17,11 +17,12 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -35,7 +36,6 @@ import (
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, "", err
}

2
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 h1:ecs4GMANgObopiy29zMmz2dIdOTJMwezUbrFy+zfSwE=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63/go.mod h1:JIWpjbCgDvZIt45C9vYpikU2gRXeDWrN7SiyGYd3Qrc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -19,6 +19,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -42,6 +43,7 @@ type Controller struct {
accountManagerMetrics *telemetry.AccountManagerMetrics
peersUpdateManager network_map.PeersUpdateManager
settingsManager settings.Manager
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
@@ -70,7 +72,7 @@ type bufferUpdate struct {
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller {
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
if err != nil {
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
@@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
dnsDomain: dnsDomain,
config: config,
proxyController: proxyController,
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
@@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
}
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
}
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
}
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
c.peersUpdateManager.CloseChannel(ctx, peerID)
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
return
}
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
}
func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams()
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
@@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountId)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountId)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerId)
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
@@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
c.UpdatePeerInNetworkMapCache(accountId, peer)
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
return nil
}
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
for _, peerID := range peerIDs {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountID)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
for _, peerID := range peerIDs {
c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
@@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return networkMap, nil
}
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) IsConnected(peerID string) bool {
return c.peersUpdateManager.HasChannel(peerID)
}

View File

@@ -12,6 +12,8 @@ type Repository interface {
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
}
type repository struct {
@@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
return r.store.GetAccountByPeerID(ctx, peerID)
}
func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
}
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}

View File

@@ -28,12 +28,12 @@ type Controller interface {
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
DeletePeer(ctx context.Context, accountId string, peerId string) error
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
DisconnectPeers(ctx context.Context, peerIDs []string)
IsConnected(peerID string) bool
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
}

View File

@@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
}
// DeletePeer mocks base method.
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
ret := m.ctrl.Call(m, "CountStreams")
ret0, _ := ret[0].(int)
return ret0
}
// DeletePeer indicates an expected call of DeletePeer.
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
// CountStreams indicates an expected call of CountStreams.
func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams))
}
// DisconnectPeers mocks base method.
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs)
}
// DisconnectPeers indicates an expected call of DisconnectPeers.
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs)
}
// GetDNSDomain mocks base method.
@@ -130,58 +130,73 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// IsConnected mocks base method.
func (m *MockController) IsConnected(peerID string) bool {
// OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsConnected", peerID)
ret0, _ := ret[0].(bool)
return ret0
ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID)
ret0, _ := ret[0].(chan *UpdateMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsConnected indicates an expected call of IsConnected.
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
// OnPeerConnected indicates an expected call of OnPeerConnected.
func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
}
// OnPeerAdded mocks base method.
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
// OnPeerDisconnected mocks base method.
func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID)
}
// OnPeerDisconnected indicates an expected call of OnPeerDisconnected.
func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID)
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerAdded indicates an expected call of OnPeerAdded.
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
}
// OnPeerDeleted mocks base method.
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
}
// OnPeerUpdated mocks base method.
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
}
// StartWarmup mocks base method.

View File

@@ -2,10 +2,15 @@ package ephemeral
import (
"context"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
EphemeralLifeTime = 10 * time.Minute
)
type Manager interface {
LoadInitialPeers(ctx context.Context)
Stop()

View File

@@ -7,14 +7,15 @@ import (
log "github.com/sirupsen/logrus"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
)
const (
ephemeralLifeTime = 10 * time.Minute
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
cleanupWindow = 1 * time.Minute
)
@@ -33,11 +34,11 @@ type ephemeralPeer struct {
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
type EphemeralManager struct {
store store.Store
accountManager nbAccount.Manager
store store.Store
peersManager peers.Manager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
@@ -49,12 +50,12 @@ type EphemeralManager struct {
}
// NewEphemeralManager instantiate new EphemeralManager
func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager {
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
store: store,
accountManager: accountManager,
store: store,
peersManager: peersManager,
lifeTime: ephemeralLifeTime,
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
}
}
@@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period.
// is inactive it will be deleted after the EphemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
@@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
e.peersLock.Unlock()
bufferAccountCall := make(map[string]struct{})
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} else {
bufferAccountCall[p.accountID] = struct{}{}
}
}
for accountID := range bufferAccountCall {
e.accountManager.BufferUpdateAccountPeers(ctx, accountID)
}
}
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {

View File

@@ -7,10 +7,13 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
@@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) {
}
store := &MockStore{}
am := MockAccountManager{
store: store,
}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, &am)
// Expect DeletePeers to be called for ephemeral peers
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1)
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
@@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) {
}
store := &MockStore{}
am := MockAccountManager{
store: store,
}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, &am)
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
@@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
store := &MockStore{}
am := MockAccountManager{
store: store,
}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, &am)
// Expect DeletePeers to be called for the one disconnected peer
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
@@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
@@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mockStore := &MockStore{}
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
wg := &sync.WaitGroup{}
wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
wg: wg,
}
mockAM.wg = &sync.WaitGroup{}
mockAM.wg.Add(ephemeralPeers)
mgr := NewEphemeralManager(mockStore, mockAM)
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
// Set up expectation that DeletePeers will be called once with all peer IDs
peersManager.EXPECT().
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
// Simulate the actual deletion behavior
for _, peerID := range peerIDs {
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
return err
}
}
mockAM.BufferUpdateAccountPeers(ctx, accountID)
return nil
}).
Times(1)
mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
// Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
time.Sleep(testCleanupWindow / ephemeralPeers)
mgr.OnPeerDisconnected(context.Background(), p)
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
mockAM.wg.Wait()
// Advance time past the lifetime to trigger cleanup
startTime = startTime.Add(testLifeTime + testCleanupWindow)
// Wait for all deletions to complete
wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")

View File

@@ -0,0 +1,162 @@
package peers
//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"fmt"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error
SetNetworkMapController(networkMapController network_map.Controller)
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
SetAccountManager(accountManager account.Manager)
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
accountManager account.Manager
networkMapController network_map.Controller
}
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) {
m.networkMapController = networkMapController
}
func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
m.integratedPeerValidator = integratedPeerValidator
}
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
m.accountManager = accountManager
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
dnsDomain := m.networkMapController.GetDNSDomain(settings)
for _, peerID := range peerIDs {
var eventsToStore []func()
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
return nil
}
if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil {
return fmt.Errorf("failed to remove peer %s from groups", peerID)
}
if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
return err
}
peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
for _, rule := range peerPolicyRules {
policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID)
if err != nil {
return err
}
err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID)
if err != nil {
return err
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
})
}
if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil {
return err
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
return nil
})
if err != nil {
return err
}
for _, event := range eventsToStore {
event()
}
}
return nil
}

View File

@@ -9,6 +9,9 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map"
account "github.com/netbirdio/netbird/management/server/account"
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
peer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// DeletePeers mocks base method.
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected)
ret0, _ := ret[0].(error)
return ret0
}
// DeletePeers indicates an expected call of DeletePeers.
func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected)
}
// GetAllPeers mocks base method.
func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
@@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
}
// SetAccountManager mocks base method.
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAccountManager", accountManager)
}
// SetAccountManager indicates an expected call of SetAccountManager.
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
}
// SetIntegratedPeerValidator mocks base method.
func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator)
}
// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator.
func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator)
}
// SetNetworkMapController mocks base method.
func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetworkMapController", networkMapController)
}
// SetNetworkMapController indicates an expected call of SetNetworkMapController.
func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
}

View File

@@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics {
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
if err != nil {
log.Fatalf("failed to create store: %v", err)
}
@@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
if s.config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update config with activity store key")
s.config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
if s.Config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update Config with activity store key")
s.Config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
if err != nil {
log.Fatalf("failed to update config with activity store: %v", err)
log.Fatalf("failed to update Config with activity store: %v", err)
}
}
@@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler {
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.config.ReverseProxy.TrustedPeers
trustedPeers := s.Config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
@@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
if s.config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
if err != nil {
log.Fatalf("cannot load TLS credentials: %v", err)
}
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}

View File

@@ -9,17 +9,17 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
return Create(s, func() *update_channel.PeersUpdateManager {
return Create(s, func() network_map.PeersUpdateManager {
return update_channel.NewPeersUpdateManager(s.Metrics())
})
}
@@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
})
}
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
func (s *BaseServer) SecretsManager() grpc.SecretsManager {
return Create(s, func() grpc.SecretsManager {
secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager())
if err != nil {
log.Fatalf("failed to create secrets manager: %v", err)
}
return secretsManager
})
}
func (s *BaseServer) AuthManager() auth.Manager {
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
s.config.HttpConfig.AuthIssuer,
s.config.HttpConfig.AuthAudience,
s.config.HttpConfig.AuthKeysLocation,
s.config.HttpConfig.AuthUserIDClaim,
s.config.GetAuthAudiences(),
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
s.Config.HttpConfig.AuthIssuer,
s.Config.HttpConfig.AuthAudience,
s.Config.HttpConfig.AuthKeysLocation,
s.Config.HttpConfig.AuthUserIDClaim,
s.Config.GetAuthAudiences(),
s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
})
}
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
})
}
func (s *BaseServer) NetworkMapController() network_map.Controller {
return Create(s, func() *nmapcontroller.Controller {
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config)
return Create(s, func() network_map.Controller {
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config)
})
}
@@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())
})
}
func (s *BaseServer) DNSDomain() string {
return s.dnsDomain
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -14,7 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
@@ -22,12 +23,12 @@ import (
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
return Create(s, func() geolocation.Geolocation {
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate)
if err != nil {
log.Fatalf("could not initialize geolocation service: %v", err)
}
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
log.Infof("geolocation service has been initialized from %s", s.Config.Datadir)
return geo
})
@@ -60,20 +61,22 @@ func (s *BaseServer) SettingsManager() settings.Manager {
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
return peers.NewManager(s.Store(), s.PermissionsManager())
manager := peers.NewManager(s.Store(), s.PermissionsManager())
s.AfterInit(func(s *BaseServer) {
manager.SetNetworkMapController(s.NetworkMapController())
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
manager.SetAccountManager(s.AccountManager())
})
return manager
})
}
func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetEphemeralManager(s.EphemeralManager())
})
return accountManager
})
}
@@ -82,8 +85,8 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
if s.config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
}

View File

@@ -41,10 +41,10 @@ type Server interface {
}
// Server holds the HTTP BaseServer instance.
// Add any additional fields you need, such as database connections, config, etc.
// Add any additional fields you need, such as database connections, Config, etc.
type BaseServer struct {
// config holds the server configuration
config *nbconfig.Config
// Config holds the server configuration
Config *nbconfig.Config
// container of dependencies, each dependency is identified by a unique string.
container map[string]any
// AfterInit is a function that will be called after the server is initialized
@@ -70,7 +70,7 @@ type BaseServer struct {
// NewServer initializes and configures a new Server instance
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
return &BaseServer{
config: config,
Config: config,
container: make(map[string]any),
dnsDomain: dnsDomain,
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
@@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
var tlsConfig *tls.Config
tlsEnabled := false
if s.config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
if s.Config.HttpConfig.LetsEncryptDomain != "" {
s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
}
tlsEnabled = true
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
if err != nil {
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
return err
@@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
if !s.disableMetrics {
idpManager := "disabled"
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
idpManager = s.config.IdpManagerConfig.ManagerType
if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" {
idpManager = s.Config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
go metricsWorker.Run(srvCtx)

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -55,15 +54,12 @@ const (
type Server struct {
accountManager account.Manager
settingsManager settings.Manager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager network_map.PeersUpdateManager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager ephemeral.Manager
peerLocks sync.Map
authManager auth.Manager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
peerLocks sync.Map
authManager auth.Manager
logBlockedPeers bool
blockPeersWithSameConfig bool
@@ -82,23 +78,16 @@ func NewServer(
config *nbconfig.Config,
accountManager account.Manager,
settingsManager settings.Manager,
peersUpdateManager network_map.PeersUpdateManager,
secretsManager SecretsManager,
appMetrics telemetry.AppMetrics,
ephemeralManager ephemeral.Manager,
authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
networkMapController network_map.Controller,
) (*Server, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(peersUpdateManager.CountStreams())
err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(networkMapController.CountStreams())
})
if err != nil {
return nil, err
@@ -120,16 +109,12 @@ func NewServer(
}
return &Server{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
authManager: authManager,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
@@ -163,8 +148,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
nanos := int32(now.Nanosecond())
expiresAt := &timestamp.Timestamp{Seconds: secs, Nanos: nanos}
key, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err)
return nil, errors.New("failed to get wireguard key")
}
return &proto.ServerKeyResponse{
Key: s.wgKey.PublicKey().String(),
Key: key.PublicKey().String(),
ExpiresAt: expiresAt,
}, nil
}
@@ -269,9 +260,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return err
}
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(ctx, peer)
updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
@@ -323,13 +318,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
@@ -348,9 +349,8 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
}
@@ -504,7 +504,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage,
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
key, err := s.secretsManager.GetWGKey()
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request")
}
err = encryption.DecryptMessage(peerKey, key, req.Body, parsed)
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
}
@@ -601,12 +606,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
@@ -615,14 +614,20 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
key, err := s.secretsManager.GetWGKey()
if err != nil {
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
@@ -715,14 +720,19 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
key, err := s.secretsManager.GetWGKey()
if err != nil {
return status.Errorf(codes.Internal, "failed getting server key")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp)
if err != nil {
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
@@ -752,7 +762,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
key, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get server key")
}
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
@@ -782,13 +797,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
},
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}
@@ -810,7 +825,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
key, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get server key")
}
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warn(errMSG)
@@ -838,13 +858,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
WgPubKey: key.PublicKey().String(),
Body: encryptedResp,
}, nil
}

View File

@@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
wgKey: testingServerKey,
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
key, err := mgmtServer.secretsManager.GetWGKey()
require.NoError(t, err, "should be able to get server key")
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
@@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)

View File

@@ -10,6 +10,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -29,6 +30,7 @@ type SecretsManager interface {
GenerateRelayToken() (*Token, error)
SetupRefresh(ctx context.Context, accountID, peerKey string)
CancelRefresh(peerKey string)
GetWGKey() (wgtypes.Key, error)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
@@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct {
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
wgKey wgtypes.Key
}
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
wgKey: key,
}
if turnCfg != nil {
@@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
}
}
return mgr
return mgr, nil
}
// GetWGKey returns WireGuard private key used to generate peer keys
func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) {
return m.wgKey, nil
}
// GenerateTurnToken generates new time-based secret credentials for TURN

View File

@@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
@@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {

View File

@@ -37,7 +37,6 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -77,7 +76,6 @@ type DefaultAccountManager struct {
ctx context.Context
eventStore activity.Store
geo geolocation.Geolocation
ephemeralManager ephemeral.Manager
requestBuffer *AccountRequestBuffer
@@ -238,7 +236,7 @@ func BuildManager(
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
}
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
if err != nil {
return nil, fmt.Errorf("getting cache store: %s", err)
}
@@ -263,10 +261,6 @@ func BuildManager(
return am, nil
}
func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
am.ephemeralManager = em
}
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
return am.externalCacheManager
}
@@ -2076,7 +2070,10 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return nil
}

View File

@@ -13,7 +13,6 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -124,5 +123,4 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
}

View File

@@ -25,6 +25,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -2959,8 +2961,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}
@@ -3371,7 +3373,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
t.Run("memory cache", func(t *testing.T) {
t.Run("should always return true", func(t *testing.T) {
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
cold, err := manager.isCacheCold(context.Background(), cacheStore)
@@ -3386,7 +3388,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
t.Run("should return true when no account exists", func(t *testing.T) {

View File

@@ -18,6 +18,7 @@ const (
DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days
DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days
DefaultIDPCacheCleanupInterval = 30 * time.Minute
DefaultIDPCacheOpenConn = 100
)
// UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects.

View File

@@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) {
t.Cleanup(cleanup)
t.Setenv(cache.RedisStoreEnvVar, redisURL)
}
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval)
cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn)
if err != nil {
t.Fatalf("couldn't create cache store: %s", err)
}

View File

@@ -3,6 +3,7 @@ package cache
import (
"context"
"fmt"
"math"
"os"
"time"
@@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS"
// NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar
// to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store.
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) {
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) {
redisAddr := os.Getenv(RedisStoreEnvVar)
if redisAddr != "" {
return getRedisStore(ctx, redisAddr)
return getRedisStore(ctx, redisAddr, maxConn)
}
goc := gocache.New(maxTimeout, cleanupInterval)
return gocache_store.NewGoCache(goc), nil
}
func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) {
func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) {
options, err := redis.ParseURL(redisEnvAddr)
if err != nil {
return nil, fmt.Errorf("parsing redis cache url: %s", err)
}
options.MaxIdleConns = 6
options.MinIdleConns = 3
options.MaxActiveConns = 100
options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns
options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns
options.MaxActiveConns = maxConn
options.ConnMaxIdleTime = 30 * time.Minute
options.ConnMaxLifetime = 0
options.PoolTimeout = 10 * time.Second
redisClient := redis.NewClient(options)
subCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()

View File

@@ -15,7 +15,7 @@ import (
)
func TestMemoryStore(t *testing.T) {
memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("couldn't create memory store: %s", err)
}
@@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) {
func TestRedisStoreConnectionFailure(t *testing.T) {
t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379")
_, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond)
_, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100)
if err == nil {
t.Fatal("getting redis cache store should return error")
}
@@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) {
}
t.Setenv(cache.RedisStoreEnvVar, redisURL)
redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond)
redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("couldn't create redis store: %s", err)
}

View File

@@ -12,6 +12,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -223,7 +225,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
@@ -39,7 +40,6 @@ import (
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
nbpeers "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/telemetry"
)

View File

@@ -45,19 +45,6 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
}
}
func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
peerToReturn := peer.Copy()
if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if !h.networkMapController.IsConnected(peer.ID) {
peerToReturn.Status.Connected = false
}
}
return peerToReturn, nil
}
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
if err != nil {
@@ -65,11 +52,6 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(ctx, err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
@@ -91,7 +73,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -237,13 +219,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)

View File

@@ -109,14 +109,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
GetDNSDomain(gomock.Any()).
Return("domain").
AnyTimes()
networkMapController.EXPECT().
IsConnected(noUpdateChannelTestPeerID).
Return(false).
AnyTimes()
networkMapController.EXPECT().
IsConnected(gomock.Any()).
Return(true).
AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
@@ -269,14 +261,6 @@ func TestGetPeers(t *testing.T) {
expectedArray: false,
expectedPeer: peer,
},
{
name: "GetPeer with no update channel",
requestType: http.MethodGet,
requestPath: "/api/peers/" + peer1.ID,
expectedStatus: http.StatusOK,
expectedArray: false,
expectedPeer: expectedPeer1,
},
{
name: "PutPeer",
requestType: http.MethodPut,
@@ -336,8 +320,6 @@ func TestGetPeers(t *testing.T) {
for _, peer := range respBody {
if peer.Id == testPeerID {
got = peer
} else {
assert.Equal(t, peer.Connected, false)
}
}
@@ -351,14 +333,14 @@ func TestGetPeers(t *testing.T) {
t.Log(got)
assert.Equal(t, got.Name, tc.expectedPeer.Name)
assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion)
assert.Equal(t, got.Ip, tc.expectedPeer.IP.String())
assert.Equal(t, got.Os, "OS core")
assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled)
assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled)
assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected)
assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber)
assert.Equal(t, tc.expectedPeer.Name, got.Name)
assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version)
assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip)
assert.Equal(t, "OS core", got.Os)
assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled)
assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled)
assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected)
assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber)
})
}
}

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server"
@@ -28,7 +30,6 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -72,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)

View File

@@ -24,13 +24,14 @@ import (
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -363,7 +364,9 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
@@ -372,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
cleanup()
return nil, nil, "", cleanup, err
}
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, nil, "", cleanup, err
}

View File

@@ -22,13 +22,14 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -205,7 +206,7 @@ func startServer(
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config)
accountManager, err := server.BuildManager(
context.Background(),
@@ -228,15 +229,16 @@ func startServer(
}
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatalf("failed creating secrets manager: %v", err)
}
mgmtServer, err := nbgrpc.NewServer(
config,
accountManager,
settingsMockManager,
updateManager,
secretsManager,
nil,
&manager.EphemeralManager{},
nil,
server.MockIntegratedValidator{},
networkMapController,

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -976,11 +975,6 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth a
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}
// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface
func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) {
// Mock implementation - does nothing
}
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
if am.AllowSyncFunc != nil {
return am.AllowSyncFunc(key, hash)

View File

@@ -13,6 +13,8 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -792,7 +794,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}

View File

@@ -136,7 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return nil
@@ -309,7 +312,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
return peer, nil
}
@@ -365,13 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
return nil
@@ -583,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if temporary {
// we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually
am.ephemeralManager.OnPeerDisconnected(ctx, newPeer)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -645,7 +641,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil {
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
@@ -729,7 +725,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
@@ -857,7 +856,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)

View File

@@ -28,6 +28,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
@@ -1058,6 +1060,7 @@ func testUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}
@@ -1290,7 +1293,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1375,7 +1378,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1528,7 +1531,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
@@ -1608,7 +1611,7 @@ func Test_LoginPeer(t *testing.T) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)

View File

@@ -1,68 +0,0 @@
package peers
//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
}
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}

View File

@@ -16,6 +16,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -1291,7 +1293,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {

View File

@@ -263,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return err
}
updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
_, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
@@ -998,14 +994,17 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peer.UserID, peer.ID, accountID,
activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
)
}
am.networkMapController.OnPeerUpdated(accountID, peer)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
am.networkMapController.DisconnectPeers(ctx, peerIDs)
am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs)
}
return nil
}
@@ -1051,7 +1050,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
var allErrors error
var updateAccountPeers bool
for _, targetUserID := range targetUserIDs {
if initiatorUserID == targetUserID {
@@ -1082,19 +1080,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
_, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if userHadPeers {
updateAccountPeers = true
}
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return allErrors
@@ -1152,15 +1142,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return false, err
}
var peerIDs []string
for _, peer := range userPeers {
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err)
}
peerIDs = append(peerIDs, peer.ID)
}
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
}
for _, addPeerRemovedEvent := range addPeerRemovedEvents {

View File

@@ -8,8 +8,10 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) {
permissionsManager: permissionsManager,
}
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
require.NoError(t, err)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs)
@@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
Store: store,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
networkMapController: networkMapControllerMock,
}
testCases := []struct {
@@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
permissionsManager := permissions.NewManager(store)
am := DefaultAccountManager{
Store: store,
eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MockIntegratedValidator{},
permissionsManager: permissionsManager,
networkMapController: networkMapControllerMock,
}
testCases := []struct {
@@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
permissionsManager: permissionsManager,
}
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval)
cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
assert.NoError(t, err)
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
@@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -21,6 +21,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/client/system"
@@ -31,8 +33,6 @@ import (
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -117,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
@@ -125,8 +125,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
groupsManager := groups.NewManagerMock()
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil {
t.Fatal(err)
}