Compare commits

...

9 Commits

Author SHA1 Message Date
pascal
5c049f6f09 delete tests 2026-02-24 16:53:49 +01:00
pascal
740c726a78 extract proxy controller 2026-02-23 15:28:20 +01:00
pascal
3af287ebab move service manager 2026-02-20 01:21:05 +01:00
pascal
d4d885d434 Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy 2026-02-20 00:35:10 +01:00
pascal
d212332f5d store proxies in DB 2026-02-20 00:28:45 +01:00
pascal
0e11258e97 fix test 2026-02-19 13:30:26 +01:00
pascal
31ecf8f1f5 allow redis for token store 2026-02-19 13:28:36 +01:00
pascal
e2df1fb35e add accountID when sending update to cluster 2026-02-19 12:02:31 +01:00
pascal
942cd5dc72 export methods 2026-02-19 10:50:32 +01:00
47 changed files with 949 additions and 3827 deletions

View File

@@ -27,21 +27,21 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}
type Manager struct {
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
proxyManager proxyManager
permissionsManager permissions.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
@@ -221,21 +228,14 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
}
return reverseProxyAddresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}

View File

@@ -0,0 +1,15 @@
package proxy
import (
"context"
"time"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}

View File

@@ -0,0 +1,106 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
}
// NewManager creates a new proxy Manager
func NewManager(store store) Manager {
return Manager{
store: store,
}
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,9 +1,11 @@
package reverseproxy
package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Manager interface {
@@ -13,7 +15,7 @@ type Manager interface {
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
@@ -21,3 +23,12 @@ type Manager interface {
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}
// ProxyController is responsible for managing proxy clusters and routing service updates.
type ProxyController interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -1,14 +1,15 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
// Package service is a generated GoMock package.
package service
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
@@ -196,7 +197,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
@@ -223,3 +224,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}
// MockProxyController is a mock of ProxyController interface.
type MockProxyController struct {
ctrl *gomock.Controller
recorder *MockProxyControllerMockRecorder
}
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
type MockProxyControllerMockRecorder struct {
mock *MockProxyController
}
// NewMockProxyController creates a new mock instance.
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
mock := &MockProxyController{ctrl: ctrl}
mock.recorder = &MockProxyControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockProxyController) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
)
type handler struct {
manager reverseproxy.Manager
manager rpservice.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCProxyController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCProxyController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
}
// NewGRPCProxyController creates a new GRPCProxyController.
func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
return &GRPCProxyController{
proxyGRPCServer: proxyGRPCServer,
}
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -7,9 +7,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -26,26 +25,26 @@ type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
}
type managerImpl struct {
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
proxyController service.ProxyController
clusterDeriver ClusterDeriver
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager {
return &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer,
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -69,34 +68,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range s.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
@@ -105,7 +104,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -126,7 +125,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
return service, nil
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -135,29 +134,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, service); err != nil {
if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return s, nil
}
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
@@ -184,7 +183,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
return nil
}
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
@@ -202,7 +201,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
})
}
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
@@ -218,7 +217,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
return nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -242,7 +241,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(service, updateInfo)
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
@@ -254,7 +253,7 @@ type serviceUpdateInfo struct {
serviceEnabledChanged bool
}
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -292,7 +291,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
return &updateInfo, err
}
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
@@ -309,7 +308,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
return nil
}
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
@@ -323,40 +322,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
}
}
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
}
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
@@ -368,7 +367,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -377,10 +376,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
@@ -395,9 +394,9 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -406,7 +405,7 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -424,7 +423,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -441,42 +440,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -492,7 +491,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
@@ -506,7 +505,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -522,7 +521,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {

View File

@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -20,13 +20,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{
mgr := &Manager{
clusterDeriver: nil,
}
service := &reverseproxy.Service{
service := &rpservice.Service{
Domain: "example.com",
Auth: reverseproxy.AuthConfig{},
Auth: rpservice.AuthConfig{},
}
err := mgr.initializeServiceForCreate(ctx, accountID, service)
@@ -40,12 +40,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
})
t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{
mgr := &Manager{
clusterDeriver: nil,
}
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
@@ -87,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -99,7 +99,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
},
@@ -110,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -136,7 +136,7 @@ func TestCheckDomainAvailable(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError {
@@ -166,7 +166,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err)
@@ -179,9 +179,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err)
@@ -199,7 +199,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err)
@@ -215,10 +215,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "new.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
// Mock ExecuteInTransaction to execute the function immediately
@@ -237,7 +237,7 @@ func TestPersistNewService(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err)
@@ -248,10 +248,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "existing.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
mockStore.EXPECT().
@@ -260,12 +260,12 @@ func TestPersistNewService(t *testing.T) {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err)
@@ -275,21 +275,21 @@ func TestPersistNewService(t *testing.T) {
})
}
func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{}
mgr := &Manager{}
t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "hashed-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "",
},
@@ -302,18 +302,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "hashed-pin",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "",
},
@@ -326,18 +326,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "old-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "new-password",
},
@@ -352,10 +352,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}
func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{}
mgr := &Manager{}
existing := &reverseproxy.Service{
Meta: reverseproxy.ServiceMeta{
existing := &rpservice.Service{
Meta: rpservice.ServiceMeta{
CertificateIssuedAt: time.Now(),
Status: "active",
},
@@ -363,7 +363,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
SessionPublicKey: "public-key",
}
updated := &reverseproxy.Service{
updated := &rpservice.Service{
Domain: "updated.com",
}

View File

@@ -1,4 +1,4 @@
package reverseproxy
package service
import (
"errors"
@@ -26,15 +26,15 @@ const (
Delete Operation = "delete"
)
type ProxyStatus string
type Status string
const (
StatusPending ProxyStatus = "pending"
StatusActive ProxyStatus = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
StatusPending Status = "pending"
StatusActive Status = "active"
StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
log.Fatalf("failed to create certificate service: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
@@ -163,9 +163,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager())
proxyService.SetProxyManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
})
return proxyService
})
@@ -188,7 +189,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
log.Info("One-time token store initialized for proxy authentication")
return tokenStore
})

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -106,6 +108,12 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
})
}
func (s *BaseServer) ServiceProxyController() service.ProxyController {
return Create(s, func() service.ProxyController {
return nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer())
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())

View File

@@ -8,9 +8,11 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
log.Fatalf("failed to create account service: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager())
accountManager.SetServiceManager(s.ServiceManager())
})
return accountManager
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP manager if embedded Dex is configured and enabled.
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err)
log.Fatalf("failed to create embedded IDP service: %v", err)
}
return idpManager
}
// Fall back to external IdP manager
// Fall back to external IdP service
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
log.Fatalf("failed to create IDP service: %v", err)
}
}
return idpManager
})
}
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
})
}
@@ -190,15 +192,21 @@ func (s *BaseServer) RecordsManager() records.Manager {
})
}
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
return proxymanager.NewManager(s.Store())
})
}
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
return &m
})
}

View File

@@ -1,202 +0,0 @@
package grpc
import (
"fmt"
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache cache.DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &cache.DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func TestBuildJWTConfig_Audiences(t *testing.T) {
tests := []struct {
name string
authAudience string
cliAuthAudience string
expectedAudiences []string
expectedAudience string
}{
{
name: "only_auth_audience",
authAudience: "dashboard-aud",
cliAuthAudience: "",
expectedAudiences: []string{"dashboard-aud"},
expectedAudience: "dashboard-aud",
},
{
name: "both_audiences_different",
authAudience: "dashboard-aud",
cliAuthAudience: "cli-aud",
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
expectedAudience: "cli-aud",
},
{
name: "both_audiences_same",
authAudience: "same-aud",
cliAuthAudience: "same-aud",
expectedAudiences: []string{"same-aud"},
expectedAudience: "same-aud",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &nbconfig.HttpServerConfig{
AuthIssuer: "https://issuer.example.com",
AuthAudience: tc.authAudience,
CLIAuthAudience: tc.cliAuthAudience,
}
result := buildJWTConfig(config, nil)
assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}
}

View File

@@ -1,276 +0,0 @@
package grpc
import (
"hash/fnv"
"math"
"math/rand"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func testAdvancedCfg() *lfConfig {
return &lfConfig{
reconnThreshold: 50 * time.Millisecond,
baseBlockDuration: 100 * time.Millisecond,
reconnLimitForBan: 3,
metaChangeLimit: 2,
}
}
type LoginFilterTestSuite struct {
suite.Suite
filter *loginFilter
}
func (s *LoginFilterTestSuite) SetupTest() {
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
}
func TestLoginFilterTestSuite(t *testing.T) {
suite.Run(t, new(LoginFilterTestSuite))
}
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
}
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.False(s.filter.allowLogin(pubKey, meta))
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
baseBan := s.filter.cfg.baseBlockDuration
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(1, s.filter.logged[pubKey].banLevel)
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
s.filter.logged[pubKey].isBanned = false
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
// nolint
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
isBanned: true,
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
}
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.False(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
currentHash: meta,
banLevel: 3,
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
}
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(0, s.filter.logged[pubKey].banLevel)
}
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
pubKey := "PUB_KEY_A"
limit := s.filter.cfg.metaChangeLimit
for i := range limit {
s.filter.addLogin(pubKey, uint64(i+1))
}
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
s.False(isAllowed, "should block new meta hash after limit is reached")
}
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
pubKey := "PUB_KEY_A"
meta1 := uint64(1)
meta2 := uint64(2)
meta3 := uint64(3)
s.filter.addLogin(pubKey, meta1)
s.filter.addLogin(pubKey, meta2)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
s.filter.addLogin(pubKey, meta3)
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
}
func BenchmarkHashingMethods(b *testing.B) {
meta := nbpeer.PeerSystemMeta{
WtVersion: "1.25.1",
OSVersion: "Ubuntu 22.04.3 LTS",
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
b.Run("BuilderString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
}
})
b.Run("FnvHashToString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
}
})
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
}
})
_ = resultString
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000
pubKeys := make([]string, numKeys)
for i := range numKeys {
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := pubKeys[r.Intn(numKeys)]
meta := r.Uint64()
if filter.allowLogin(key, meta) {
filter.addLogin(key, meta)
}
}
})
}

View File

@@ -1,28 +1,23 @@
package grpc
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ServiceID string
AccountID string
@@ -30,20 +25,24 @@ type tokenMetadata struct {
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
// of expired tokens. The cleanupInterval determines how often expired
// tokens are removed from memory.
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
store := &OneTimeTokenStore{
tokens: make(map[string]*tokenMetadata),
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}),
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type OneTimeTokenStore struct {
cache *cache.Cache[string]
ctx context.Context
}
// NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
// Start background cleanup goroutine
go store.cleanupExpired()
return store
return &OneTimeTokenStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
// GenerateToken creates a new cryptographically secure one-time token
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err)
}
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes)
hashedToken := hashToken(token)
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
metadata := &tokenMetadata{
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
}
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
return "", fmt.Errorf("failed to store token: %w", err)
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl)
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
hashedToken := hashToken(token)
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
serviceID, accountID)
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
if err != nil {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
return fmt.Errorf("invalid token")
}
// Check expiration
metadata := &tokenMetadata{}
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
return fmt.Errorf("invalid token metadata")
}
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
return fmt.Errorf("token expired")
}
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
metadata.AccountID, accountID)
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch")
}
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
}
log.Infof("Token validated and consumed for proxy %s in account %s",
serviceID, accountID)
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
return nil
}
// cleanupExpired removes expired tokens in the background to prevent memory leaks
func (s *OneTimeTokenStore) cleanupExpired() {
for {
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -24,8 +24,9 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
@@ -58,9 +59,6 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
@@ -68,7 +66,13 @@ type ProxyServiceServer struct {
accessLogManager accesslogs.Manager
// Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager
serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController rpservice.ProxyController
// Manager for proxy connections
proxyManager proxy.Manager
// Manager for peers
peersManager peers.Manager
@@ -107,7 +111,7 @@ type proxyConnection struct {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
@@ -116,9 +120,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
tokenStore: tokenStore,
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel,
}
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s
}
@@ -142,13 +148,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
}
}
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
}
}
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
s.reverseProxyManager = manager
func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) {
s.serviceManager = manager
}
func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) {
s.proxyController = proxyController
}
// GetMappingUpdate handles the control stream with proxy clients
@@ -183,7 +209,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
@@ -191,8 +225,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s disconnected", proxyID)
}()
@@ -204,6 +245,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2)
go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
select {
case err := <-errChan:
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
@@ -212,10 +256,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
return
}
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get services from store: %w", err)
}
@@ -224,7 +285,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("proxy address is invalid")
}
var filtered []*reverseproxy.Service
var filtered []*rpservice.Service
for _, service := range services {
if !service.Enabled {
continue
@@ -259,7 +320,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
),
@@ -393,61 +454,43 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
return urls
}
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
if clusterAddr == "" {
s.SendServiceUpdate(update)
return
}
proxySet, ok := s.clusterProxies.Load(clusterAddr)
if !ok {
log.Debugf("No proxies connected for cluster %s", clusterAddr)
if s.proxyController == nil {
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
return
}
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
if len(proxyIDs) == 0 {
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
return
}
log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
for _, proxyID := range proxyIDs {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID)
if msg == nil {
return true
continue
}
select {
case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
})
}
}
// perProxyMessage returns a copy of update with a fresh one-time token for
@@ -486,35 +529,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
}
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -533,7 +549,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}, nil
}
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
@@ -544,7 +560,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
}
}
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
return false, "", ""
@@ -558,7 +574,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
return true, "pin-user", proxyauth.MethodPIN
}
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
return false, "", ""
@@ -580,7 +596,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
}
}
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" {
return "", nil
}
@@ -620,7 +636,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}
if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
}
@@ -632,7 +648,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
@@ -647,22 +663,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}
// protoStatusToInternal maps proto status to internal status
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING:
return reverseproxy.StatusPending
return rpservice.StatusPending
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
return reverseproxy.StatusActive
return rpservice.StatusActive
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
return reverseproxy.StatusTunnelNotCreated
return rpservice.StatusTunnelNotCreated
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
return reverseproxy.StatusCertificatePending
return rpservice.StatusCertificatePending
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
return reverseproxy.StatusCertificateFailed
return rpservice.StatusCertificateFailed
case proto.ProxyStatus_PROXY_STATUS_ERROR:
return reverseproxy.StatusError
return rpservice.StatusError
default:
return reverseproxy.StatusError
return rpservice.StatusError
}
}
@@ -727,7 +743,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
@@ -790,8 +806,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
return reverseproxy.OIDCValidationConfig{
func (s *ProxyServiceServer) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return rpservice.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation,
@@ -850,12 +866,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *reverseproxy.Service
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
@@ -921,8 +937,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
}
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account services: %w", err)
}
@@ -1043,8 +1059,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
@@ -1058,7 +1074,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return nil, fmt.Errorf("service not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}

View File

@@ -1,98 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -1,381 +0,0 @@
package grpc
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
err error
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type mockUsersManager struct {
users map[string]*types.User
err error
}
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[userID]
if !ok {
return nil, errors.New("user not found")
}
return user, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
users map[string]*types.User
proxyErr error
userErr error
expectErr bool
expectErrMsg string
}{
{
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
expectErr: true,
expectErrMsg: "user not found",
},
{
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
},
expectErr: true,
expectErrMsg: "not in allowed groups",
},
{
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
},
expectErr: false,
},
{
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
},
expectErr: false,
},
{
name: "proxy manager error",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: nil,
proxyErr: errors.New("database error"),
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account services",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
usersManager: &mockUsersManager{
users: tt.users,
err: tt.userErr,
},
}
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErrMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
err error
expectProxy bool
expectErr bool
}{
{
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
},
},
expectProxy: true,
expectErr: false,
},
{
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
expectErr: true,
},
{
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
expectProxy: false,
expectErr: true,
},
{
name: "manager error",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: nil,
err: errors.New("database error"),
expectProxy: false,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},
}
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)
assert.Nil(t, proxy)
} else {
require.NoError(t, err)
require.NotNil(t, proxy)
assert.Equal(t, tt.domain, proxy.Domain)
}
})
}
}

View File

@@ -1,232 +0,0 @@
package grpc
import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
select {
case msg := <-ch:
return msg
case <-time.After(time.Second):
return nil
}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080/"},
},
}
s.SendServiceUpdateToCluster(update, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
// All tokens must be unique
tokenSet := make(map[string]struct{})
for i, tok := range tokens {
_, exists := tokenSet[tok]
assert.False(t, exists, "proxy %d got duplicate token", i)
tokenSet[tok] = struct{}{}
}
// Each token must be independently consumable
for i, tok := range tokens {
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
assert.NoError(t, err, "proxy %d token should validate successfully", i)
}
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(update, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
// Both tokens should validate
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
// generateState creates a state using the same format as GetOIDCURL.
func generateState(s *ProxyServiceServer, redirectURL string) string {
nonce := make([]byte, 16)
_, _ = rand.Read(nonce)
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
payload := redirectURL + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
}
func TestOAuthState_NeverTheSame(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
redirectURL := "https://app.example.com/callback"
// Generate 100 states for the same redirect URL
states := make(map[string]bool)
for i := 0; i < 100; i++ {
state := generateState(s, redirectURL)
// State must have 3 parts: base64(url)|nonce|hmac
parts := strings.Split(state, "|")
require.Equal(t, 3, len(parts), "state must have 3 parts")
// State must be unique
require.False(t, states[state], "state %d is a duplicate", i)
states[state] = true
}
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -1,108 +0,0 @@
package grpc
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
key, err := mgmtServer.secretsManager.GetWGKey()
require.NoError(t, err, "should be able to get server key")
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}

View File

@@ -1,250 +0,0 @@
package grpc
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"hash"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &config.Host{
Proto: config.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
}
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
if turnCredentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty")
}
if turnCredentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty")
}
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
relayCredentials, err := tested.GenerateRelayToken()
require.NoError(t, err)
if relayCredentials.Payload == "" {
t.Errorf("expected generated relay payload not to be empty, got empty")
}
if relayCredentials.Signature == "" {
t.Errorf("expected generated relay signature not to be empty, got empty")
}
hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
}
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tested.SetupRefresh(ctx, "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
}
var updates []*network_map.UpdateMessage
loop:
for timeout := time.After(5 * time.Second); ; {
select {
case update := <-updateChannel:
updates = append(updates, update)
case <-timeout:
break loop
}
if len(updates) >= 2 {
break loop
}
}
if len(updates) < 2 {
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
}
var turnUpdates, relayUpdates int
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
for _, update := range updates {
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
turnUpdates++
if turnUpdates == 1 {
firstTurnUpdate = turns[0]
} else {
secondTurnUpdate = turns[0]
}
}
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
// avoid updating on turn updates since they also send relay credentials
if update.Update.GetNetbirdConfig().GetTurns() == nil {
relayUpdates++
if relayUpdates == 1 {
firstRelayUpdate = relay
} else {
secondRelayUpdate = relay
}
}
}
}
if turnUpdates < 1 {
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
}
if relayUpdates < 1 {
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
}
if firstTurnUpdate != nil && secondTurnUpdate != nil {
if firstTurnUpdate.Password == secondTurnUpdate.Password {
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
}
}
if firstRelayUpdate != nil && secondRelayUpdate != nil {
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
}
}
}
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in relay cancel map, got not present")
}
tested.CancelRefresh(peer)
if _, ok := tested.turnCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in turn cancel map, got present")
}
if _, ok := tested.relayCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in relay cancel map, got present")
}
}
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
t.Helper()
mac := hmac.New(algo, key)
_, err := mac.Write([]byte(username))
if err != nil {
t.Fatal(err)
}
expectedMAC := mac.Sum(nil)
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
if err != nil {
t.Fatal(err)
}
equal := hmac.Equal(decodedMAC, expectedMAC)
if !equal {
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
}
}

View File

@@ -1,587 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
shouldSend := debouncer.ProcessUpdate(update)
if !shouldSend {
t.Error("First update should be sent immediately")
}
if debouncer.TimerChannel() == nil {
t.Error("Timer should be started after first update")
}
}
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update should be sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Rapid subsequent updates should be coalesced
if debouncer.ProcessUpdate(update2) {
t.Error("Second rapid update should not be sent immediately")
}
if debouncer.ProcessUpdate(update3) {
t.Error("Third rapid update should not be sent immediately")
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Send second update within debounce period
debouncer.ProcessUpdate(update2)
// Wait for timer
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update2 {
t.Error("Should get the last update")
}
if pendingUpdates[0] == update1 {
t.Error("Should not get the first update")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Wait a bit, but not the full debounce period
time.Sleep(30 * time.Millisecond)
// Send second update - should reset timer
debouncer.ProcessUpdate(update2)
// Wait a bit more
time.Sleep(30 * time.Millisecond)
// Send third update - should reset timer again
debouncer.ProcessUpdate(update3)
// Now wait for the timer (should fire after last update's reset)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
// Timer should be restarted since there was a pending update
if debouncer.TimerChannel() == nil {
t.Error("Timer should be restarted after sending pending update")
}
case <-time.After(150 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
debouncer.ProcessUpdate(update1)
// Second update coalesced
debouncer.ProcessUpdate(update2)
// Wait for timer to expire
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have pending update")
}
// After sending pending update, timer is restarted, so next update is NOT immediate
if debouncer.ProcessUpdate(update3) {
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
}
// Wait for the restarted timer and verify update3 is pending
select {
case <-debouncer.TimerChannel():
finalUpdates := debouncer.GetPendingUpdates()
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
t.Error("Should get update3 as pending")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired for restarted timer")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send update to start timer
debouncer.ProcessUpdate(update)
// Stop should clean up
debouncer.Stop()
// Multiple stops should be safe
debouncer.Stop()
}
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate high-frequency updates
var lastUpdate *network_map.UpdateMessage
sentImmediately := 0
for i := 0; i < 100; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
lastUpdate = update
if debouncer.ProcessUpdate(update) {
sentImmediately++
}
time.Sleep(1 * time.Millisecond) // Very rapid updates
}
// Only first update should be sent immediately
if sentImmediately != 1 {
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != lastUpdate {
t.Error("Should get the very last update")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
// Wait for timer to expire with no additional updates (true quiet period)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates")
}
// After true quiet period, timer should be cleared
if debouncer.TimerChannel() != nil {
t.Error("Timer should be cleared after quiet period")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
updates := make([]*network_map.UpdateMessage, 5)
for i := range updates {
updates[i] = &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
}
// First update sent immediately
debouncer.ProcessUpdate(updates[0])
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
debouncer.ProcessUpdate(updates[1])
debouncer.ProcessUpdate(updates[2])
debouncer.ProcessUpdate(updates[3])
debouncer.ProcessUpdate(updates[4])
// Wait for debounce
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
}
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Wait for timer without sending any more updates (true quiet period)
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates during quiet period")
}
// After true quiet period, next update should be sent immediately
if !debouncer.ProcessUpdate(update2) {
t.Error("Update after true quiet period should be sent immediately")
}
}
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate continuous high-frequency updates
for i := 0; i < 10; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
if i == 0 {
// First one sent immediately
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
} else {
// All others should be coalesced (not sent immediately)
if debouncer.ProcessUpdate(update) {
t.Errorf("Update %d should not be sent immediately", i)
}
}
// Wait a bit but send next update before debounce expires
time.Sleep(20 * time.Millisecond)
}
// Now wait for final debounce
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have the last update pending")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
tokenUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate)
// Send multiple control config updates - they should all be queued
debouncer.ProcessUpdate(tokenUpdate1)
debouncer.ProcessUpdate(tokenUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get both control config updates
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
}
// Control configs should come first
if pendingUpdates[0] != tokenUpdate1 {
t.Error("First pending update should be tokenUpdate1")
}
if pendingUpdates[1] != tokenUpdate2 {
t.Error("Second pending update should be tokenUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
netmapUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate1)
// Send token update and network map update
debouncer.ProcessUpdate(tokenUpdate)
debouncer.ProcessUpdate(netmapUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get 2 updates in order: token, then network map
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
}
// Token update should come first (preserves order)
if pendingUpdates[0] != tokenUpdate {
t.Error("First pending update should be tokenUpdate")
}
// Network map update should come second
if pendingUpdates[1] != netmapUpdate2 {
t.Error("Second pending update should be netmapUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate: 50 network maps -> 1 control config -> 50 network maps
// Expected result: 3 messages (netmap, controlConfig, netmap)
// Send first network map immediately
firstNetmap := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
MessageType: network_map.MessageTypeNetworkMap,
}
if !debouncer.ProcessUpdate(firstNetmap) {
t.Error("First update should be sent immediately")
}
// Send 49 more network maps (will be coalesced to last one)
var lastNetmapBatch1 *network_map.UpdateMessage
for i := 1; i < 50; i++ {
lastNetmapBatch1 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch1)
}
// Send 1 control config
controlConfig := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
debouncer.ProcessUpdate(controlConfig)
// Send 50 more network maps (will be coalesced to last one)
var lastNetmapBatch2 *network_map.UpdateMessage
for i := 50; i < 100; i++ {
lastNetmapBatch2 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch2)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get exactly 3 updates: netmap, controlConfig, netmap
if len(pendingUpdates) != 3 {
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
}
// First should be the last netmap from batch 1
if pendingUpdates[0] != lastNetmapBatch1 {
t.Error("First pending update should be last netmap from batch 1")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
// Second should be the control config
if pendingUpdates[1] != controlConfig {
t.Error("Second pending update should be control config")
}
// Third should be the last netmap from batch 2
if pendingUpdates[2] != lastNetmapBatch2 {
t.Error("Third pending update should be last netmap from batch 2")
}
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}

View File

@@ -1,304 +0,0 @@
//go:build integration
package grpc
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
type validateSessionTestSetup struct {
proxyService *ProxyServiceServer
store store.Store
cleanup func()
}
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
createTestProxies(t, ctx, testStore)
return &validateSessionTestSetup{
proxyService: proxyService,
store: testStore,
cleanup: storeCleanup,
}
}
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
func generateSessionKeyPair(t *testing.T) (string, string) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
}
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
func TestValidateSession_UserAllowed(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "restricted-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User in different account should be denied")
assert.Equal(t, "account_mismatch", resp.DeniedReason)
}
func TestValidateSession_UserNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Non-existent user should be denied")
assert.Equal(t, "user_not_found", resp.DeniedReason)
}
func TestValidateSession_ProxyNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "unknown-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: "invalid-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Invalid token should be denied")
assert.Equal(t, "invalid_token", resp.DeniedReason)
}
func TestValidateSession_MissingDomain(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
SessionToken: "some-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
func TestValidateSession_MissingToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}

View File

@@ -15,7 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
proxyController port_forwarding.Controller
settingsManager settings.Manager
serviceManager service.Manager
// config contains the management server configuration
config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.serviceManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
@@ -394,7 +394,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
}
if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
}

View File

@@ -6,7 +6,7 @@ import (
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -140,5 +140,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager)
SetServiceManager(serviceManager service.Manager)
}

View File

@@ -27,8 +27,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"

View File

@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
}
// Register OAuth callback handler for proxy authentication

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore)

View File

@@ -12,7 +12,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +92,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(reverseProxyManager)
am.SetServiceManager(reverseProxyManager)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
proxyMgr := proxymanager.NewManager(store)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -12,7 +12,7 @@ import (
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
@@ -148,7 +148,7 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
// Mock implementation - no-op
}

View File

@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -33,23 +33,23 @@ type Manager interface {
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
reverseProxyManager reverseproxy.Manager
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
serviceManager service.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
reverseProxyManager: reverseproxyManager,
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
serviceManager: reverseproxyManager,
}
}
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
}
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err)
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err)
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)

View File

@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings
var eventsToStore []func()
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}

View File

@@ -28,9 +28,10 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -2063,7 +2064,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil
}
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
@@ -2078,8 +2079,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err
}
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
var s reverseproxy.Service
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
var s rpservice.Service
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
@@ -2109,7 +2110,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
}
}
s.Meta = reverseproxy.ServiceMeta{}
s.Meta = rpservice.ServiceMeta{}
if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time
}
@@ -2129,7 +2130,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
s.SessionPublicKey = sessionPublicKey.String
}
s.Targets = []*reverseproxy.Target{}
s.Targets = []*rpservice.Target{}
return &s, nil
})
if err != nil {
@@ -2141,7 +2142,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
}
serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service)
serviceMap := make(map[string]*rpservice.Service)
for i, s := range services {
serviceIDs[i] = s.ID
serviceMap[s.ID] = s
@@ -2152,8 +2153,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err
}
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
var t reverseproxy.Target
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
var t rpservice.Target
var path sql.NullString
err := row.Scan(
&t.ID,
@@ -4825,7 +4826,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil
}
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
@@ -4839,16 +4840,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
return nil
}
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
}
// Create target type instance outside transaction to avoid variable shadowing
targetType := &rpservice.Target{}
// Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
return err
}
@@ -4869,7 +4873,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
}
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store")
@@ -4882,13 +4886,13 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var service *reverseproxy.Service
var service *rpservice.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4906,8 +4910,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
var service *reverseproxy.Service
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
var service *rpservice.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4925,13 +4929,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
return service, nil
}
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
var serviceList []*rpservice.Service
result := tx.Find(&serviceList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -4947,13 +4951,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
return serviceList, nil
}
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
var serviceList []*rpservice.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -5181,13 +5185,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
return query
}
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var target *reverseproxy.Target
var target *rpservice.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -5200,3 +5204,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
return target, nil
}
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
}
return nil
}
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
}
return addresses, nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
}
if result.RowsAffected > 0 {
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
}
return nil
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
}
for i := len(models) - 1; i >= 0; i-- {

View File

@@ -25,9 +25,10 @@ import (
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -252,13 +253,13 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error
CreateService(ctx context.Context, service *rpservice.Service) error
UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
@@ -270,7 +271,12 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
const (

View File

@@ -12,9 +12,10 @@ import (
gomock "github.com/golang/mock/gomock"
dns "github.com/netbirdio/netbird/dns"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
zones "github.com/netbirdio/netbird/management/internals/modules/zones"
records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
types "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
}
// CleanupStaleProxies mocks base method.
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -293,7 +308,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
}
// CreateService mocks base method.
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, service)
ret0, _ := ret[0].(error)
@@ -1095,10 +1110,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
}
// GetAccountServices mocks base method.
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1199,6 +1214,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
}
// GetActiveProxyClusterAddresses mocks base method.
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper()
@@ -1813,10 +1843,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret0, _ := ret[0].(*reverseproxy.Service)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1828,10 +1858,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
}
// GetServiceByID mocks base method.
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].(*reverseproxy.Service)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1843,10 +1873,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
}
// GetServiceTargetByTargetID mocks base method.
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) {
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
ret0, _ := ret[0].(*reverseproxy.Target)
ret0, _ := ret[0].(*service.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1858,10 +1888,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
}
// GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -2536,6 +2566,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
}
// SaveProxy mocks base method.
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
ret0, _ := ret[0].(error)
return ret0
}
// SaveProxy indicates an expected call of SaveProxy.
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper()
@@ -2731,8 +2775,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
}
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
}
// UpdateService mocks base method.
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, service)
ret0, _ := ret[0].(error)

View File

@@ -18,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy())
}
services := []*reverseproxy.Service{}
services := []*service.Service{}
for _, service := range a.Services {
services = append(services, service.Copy())
}
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
}
}
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
for _, target := range service.Targets {
if !target.Enabled {
continue
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
}
}
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
port, ok := a.resolveTargetPort(ctx, target)
if !ok {
return
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
}
}
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
if target.Port != 0 {
return target.Port, true
}
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
}
}
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
return &Policy{
ID: policyID,

View File

@@ -1,94 +0,0 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int
}
func (m *mockMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
if m.idx >= len(m.messages) {
return nil, io.EOF
}
msg := m.messages[m.idx]
m.idx++
return msg, nil
}
func (m *mockMappingStream) Header() (metadata.MD, error) {
return nil, nil //nolint:nilnil
}
func (m *mockMappingStream) Trailer() metadata.MD { return nil }
func (m *mockMappingStream) CloseSend() error { return nil }
func (m *mockMappingStream) Context() context.Context { return context.Background() }
func (m *mockMappingStream) SendMsg(any) error { return nil }
func (m *mockMappingStream) RecvMsg(any) error { return nil }
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{InitialSyncComplete: true},
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
}
func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // no sync flag
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "initial sync should not be marked done without flag")
}
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{InitialSyncComplete: true},
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync done flag should be set even without health checker")
}

View File

@@ -1,560 +0,0 @@
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
proxytypes "github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// integrationTestSetup contains all real components for testing.
type integrationTestSetup struct {
store store.Store
proxyService *nbgrpc.ProxyServiceServer
grpcServer *grpc.Server
grpcAddr string
cleanup func()
services []*reverseproxy.Service
}
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
t.Helper()
ctx := context.Background()
// Create real SQLite store
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
// Create test account
testAccount := &types.Account{
Id: "test-account-1",
Domain: "test.com",
DomainCategory: "private",
IsDomainPrimaryAccount: true,
CreatedAt: time.Now(),
}
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
// Generate session keys for reverse proxies
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
// Create test services in the store
services := []*reverseproxy.Service{
{
ID: "rp-1",
AccountID: "test-account-1",
Name: "Test App 1",
Domain: "app1.test.proxy.io",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
ProxyCluster: "test.proxy.io",
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
},
{
ID: "rp-2",
AccountID: "test-account-1",
Name: "Test App 2",
Domain: "app2.test.proxy.io",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "10.0.0.2",
Port: 8080,
Protocol: "http",
TargetId: "peer2",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
ProxyCluster: "test.proxy.io",
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
},
}
for _, svc := range services {
require.NoError(t, testStore.CreateService(ctx, svc))
}
// Create real token store
tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute)
// Create real users manager
usersManager := users.NewManager(testStore)
// Create real proxy service server with minimal config
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: "https://fake-issuer.example.com",
ClientID: "test-client",
HMACKey: []byte("test-hmac-key"),
}
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
oidcConfig,
nil,
usersManager,
)
// Use store-backed service manager
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetProxyManager(svcMgr)
// Start real gRPC server
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
grpcServer := grpc.NewServer()
proto.RegisterProxyServiceServer(grpcServer, proxyService)
go func() {
if err := grpcServer.Serve(lis); err != nil {
t.Logf("gRPC server error: %v", err)
}
}()
return &integrationTestSetup{
store: testStore,
proxyService: proxyService,
grpcServer: grpcServer,
grpcAddr: lis.Addr().String(),
services: services,
cleanup: func() {
grpcServer.GracefulStop()
cleanup()
},
}
}
// testAccessLogManager provides access log storage for testing.
type testAccessLogManager struct{}
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
return 0, nil
}
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
// noop
}
func (m *testAccessLogManager) StopPeriodicCleanup() {
// noop
}
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
return nil
}
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
return nil, 0, nil
}
// storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct {
store store.Store
tokenStore *nbgrpc.OneTimeTokenStore
}
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, errors.New("not implemented")
}
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, errors.New("not implemented")
}
func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
return nil
}
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return nil
}
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *storeBackedServiceManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
return nil
}
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
}
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, accountID string, targetID string) (string, error) {
return "", nil
}
func strPtr(s string) *string {
return &s
}
func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "test-proxy-1",
Version: "test-v1",
Address: "test.proxy.io",
})
require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping)
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
}
// Should receive 2 mappings total
assert.Len(t, mappingsByID, 2, "Should receive 2 reverse proxy mappings")
rp1 := mappingsByID["rp-1"]
require.NotNil(t, rp1)
assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain())
assert.Equal(t, "test-account-1", rp1.GetAccountId())
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType())
assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token for peer creation")
rp2 := mappingsByID["rp-2"]
require.NotNil(t, rp2)
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
}
func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
clusterAddress := "test.proxy.io"
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "test-proxy-cluster",
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0)
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
}
// Should receive the 2 mappings matching the cluster
assert.Len(t, mappings, 2, "Should receive mappings for the cluster")
for _, mapping := range mappings {
t.Logf("Received mapping: id=%s domain=%s", mapping.GetId(), mapping.GetDomain())
}
}
func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect"
// Helper to receive all mappings from a stream
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping
for i := 0; i < count; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
}
return mappings
}
// First connection
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
firstMappings := receiveMappings(stream1, 2)
cancel1()
time.Sleep(100 * time.Millisecond)
// Second connection (simulating reconnect)
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
secondMappings := receiveMappings(stream2, 2)
// Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings),
"Should receive same number of mappings on reconnect")
firstIDs := make(map[string]bool)
for _, m := range firstMappings {
firstIDs[m.GetId()] = true
}
for _, m := range secondMappings {
assert.True(t, firstIDs[m.GetId()],
"Mapping %s should be present in both connections", m.GetId())
}
}
func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
// Use real auth middleware and proxy to verify idempotency
logger := log.New()
logger.SetLevel(log.WarnLevel)
authMw := auth.NewMiddleware(logger, nil)
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-idempotent"
var addMappingCalls atomic.Int32
applyMappings := func(mappings []*proto.ProxyMapping) {
for _, mapping := range mappings {
if mapping.GetType() == proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
addMappingCalls.Add(1)
// Apply to real auth middleware (idempotent)
err := authMw.AddDomain(
mapping.GetDomain(),
nil,
"",
0,
mapping.GetAccountId(),
mapping.GetId(),
)
require.NoError(t, err)
// Apply to real proxy (idempotent)
proxyHandler.AddMapping(proxy.Mapping{
Host: mapping.GetDomain(),
ID: mapping.GetId(),
AccountID: proxytypes.AccountID(mapping.GetAccountId()),
})
}
}
}
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
applyMappings(msg.GetMapping())
}
}
// First connection
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
receiveAndApply(stream1)
cancel1()
firstCallCount := addMappingCalls.Load()
t.Logf("First connection: applied %d mappings", firstCallCount)
time.Sleep(100 * time.Millisecond)
// Second connection
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
receiveAndApply(stream2)
cancel2()
time.Sleep(100 * time.Millisecond)
// Third connection
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel3()
stream3, err := client.GetMappingUpdate(ctx3, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
receiveAndApply(stream3)
totalCalls := addMappingCalls.Load()
t.Logf("After three connections: total applied %d mappings", totalCalls)
// Should have called addMapping 6 times (2 mappings x 3 connections)
// But internal state is NOT duplicated because auth and proxy use maps keyed by domain/host
assert.Equal(t, int32(6), totalCalls, "Should have 6 total calls (2 mappings x 3 connections)")
}
func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
clusterAddress := "test.proxy.io"
var wg sync.WaitGroup
var mu sync.Mutex
receivedByProxy := make(map[string]int)
for i := 1; i <= 3; i++ {
wg.Add(1)
go func(proxyNum int) {
defer wg.Done()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
proxyID := "test-proxy-" + string(rune('A'+proxyNum-1))
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
count += len(msg.GetMapping())
}
mu.Lock()
receivedByProxy[proxyID] = count
mu.Unlock()
}(i)
}
wg.Wait()
for proxyID, count := range receivedByProxy {
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
}
}

View File

@@ -1,106 +0,0 @@
package proxy
import (
"net"
"net/netip"
"testing"
"time"
proxyproto "github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
ProxyProtocol: true,
}
raw, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer raw.Close()
ln := srv.wrapProxyProtocol(raw)
realClientIP := "203.0.113.50"
realClientPort := uint16(54321)
accepted := make(chan net.Conn, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
accepted <- conn
}()
// Connect and send a PROXY v2 header.
conn, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
defer conn.Close()
header := &proxyproto.Header{
Version: 2,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
}
_, err = header.WriteTo(conn)
require.NoError(t, err)
select {
case accepted := <-accepted:
defer accepted.Close()
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
require.NoError(t, err)
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for connection")
}
}
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
}
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
}
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
}

View File

@@ -1,48 +0,0 @@
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDebugEndpointDisabledByDefault(t *testing.T) {
s := &Server{}
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
}
func TestDebugEndpointAddr(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty defaults to localhost",
input: "",
expected: "localhost:8444",
},
{
name: "explicit localhost preserved",
input: "localhost:9999",
expected: "localhost:9999",
},
{
name: "explicit address preserved",
input: "0.0.0.0:8444",
expected: "0.0.0.0:8444",
},
{
name: "127.0.0.1 preserved",
input: "127.0.0.1:8444",
expected: "127.0.0.1:8444",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := debugEndpointAddr(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -1,90 +0,0 @@
package proxy
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseTrustedProxies(t *testing.T) {
tests := []struct {
name string
raw string
want []netip.Prefix
wantErr bool
}{
{
name: "empty string returns nil",
raw: "",
want: nil,
},
{
name: "single CIDR",
raw: "10.0.0.0/8",
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
{
name: "single bare IPv4",
raw: "1.2.3.4",
want: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/32")},
},
{
name: "single bare IPv6",
raw: "::1",
want: []netip.Prefix{netip.MustParsePrefix("::1/128")},
},
{
name: "comma-separated CIDRs",
raw: "10.0.0.0/8, 192.168.1.0/24",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "mixed CIDRs and bare IPs",
raw: "10.0.0.0/8, 1.2.3.4, fd00::/8",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("1.2.3.4/32"),
netip.MustParsePrefix("fd00::/8"),
},
},
{
name: "whitespace around entries",
raw: " 10.0.0.0/8 , 192.168.0.0/16 ",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "trailing comma produces no extra entry",
raw: "10.0.0.0/8,",
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
{
name: "invalid entry",
raw: "not-an-ip",
wantErr: true,
},
{
name: "partially invalid",
raw: "10.0.0.0/8, garbage",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseTrustedProxies(tt.raw)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}