mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-11 11:46:28 -04:00
Compare commits
9 Commits
main
...
deploy/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c049f6f09 | ||
|
|
740c726a78 | ||
|
|
3af287ebab | ||
|
|
d4d885d434 | ||
|
|
d212332f5d | ||
|
|
0e11258e97 | ||
|
|
31ecf8f1f5 | ||
|
|
e2df1fb35e | ||
|
|
942cd5dc72 |
@@ -27,21 +27,21 @@ type store interface {
|
||||
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||
}
|
||||
|
||||
type proxyURLProvider interface {
|
||||
GetConnectedProxyURLs() []string
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyURLProvider proxyURLProvider
|
||||
proxyManager proxyManager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
proxyURLProvider: proxyURLProvider,
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{
|
||||
Resolver: net.DefaultResolver,
|
||||
},
|
||||
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList := m.proxyURLAllowList()
|
||||
log.WithFields(log.Fields{
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"proxyAllowList": allowList,
|
||||
}).Debug("getting domains with proxy allow list")
|
||||
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList := m.proxyURLAllowList()
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
clusterValid := false
|
||||
for _, cluster := range allowList {
|
||||
if cluster == targetCluster {
|
||||
@@ -221,21 +228,14 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
||||
}
|
||||
}
|
||||
|
||||
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||
// their URLs
|
||||
func (m Manager) proxyURLAllowList() []string {
|
||||
var reverseProxyAddresses []string
|
||||
if m.proxyURLProvider != nil {
|
||||
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||
}
|
||||
return reverseProxyAddresses
|
||||
}
|
||||
|
||||
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList := m.proxyURLAllowList()
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
if len(allowList) == 0 {
|
||||
return "", fmt.Errorf("no proxy clusters available")
|
||||
}
|
||||
|
||||
15
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
15
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
type Manager struct {
|
||||
store store
|
||||
}
|
||||
|
||||
// NewManager creates a new proxy Manager
|
||||
func NewManager(store store) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
@@ -13,7 +15,7 @@ type Manager interface {
|
||||
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||
ReloadService(ctx context.Context, accountID, serviceID string) error
|
||||
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
||||
@@ -21,3 +23,12 @@ type Manager interface {
|
||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||
}
|
||||
|
||||
// ProxyController is responsible for managing proxy clusters and routing service updates.
|
||||
type ProxyController interface {
|
||||
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
|
||||
GetOIDCValidationConfig() OIDCValidationConfig
|
||||
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
GetProxiesForCluster(clusterAddr string) []string
|
||||
}
|
||||
@@ -1,14 +1,15 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
|
||||
// Package reverseproxy is a generated GoMock package.
|
||||
package reverseproxy
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
proto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -196,7 +197,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
|
||||
}
|
||||
|
||||
// SetStatus mocks base method.
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -223,3 +224,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
|
||||
}
|
||||
|
||||
// MockProxyController is a mock of ProxyController interface.
|
||||
type MockProxyController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProxyControllerMockRecorder
|
||||
}
|
||||
|
||||
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
|
||||
type MockProxyControllerMockRecorder struct {
|
||||
mock *MockProxyController
|
||||
}
|
||||
|
||||
// NewMockProxyController creates a new mock instance.
|
||||
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
|
||||
mock := &MockProxyController{ctrl: ctrl}
|
||||
mock.recorder = &MockProxyControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig mocks base method.
|
||||
func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
|
||||
ret0, _ := ret[0].(OIDCValidationConfig)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
|
||||
func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
|
||||
}
|
||||
|
||||
// GetProxiesForCluster mocks base method.
|
||||
func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
|
||||
func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster mocks base method.
|
||||
func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
|
||||
func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster mocks base method.
|
||||
func (m *MockProxyController) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SendServiceUpdateToCluster", accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
|
||||
func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(accountID, update, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster mocks base method.
|
||||
func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
|
||||
func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager reverseproxy.Manager
|
||||
manager rpservice.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all service HTTP endpoints.
|
||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
service := new(reverseproxy.Service)
|
||||
service := new(rpservice.Service)
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
service := new(reverseproxy.Service)
|
||||
service := new(rpservice.Service)
|
||||
service.ID = serviceID
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// GRPCProxyController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
|
||||
type GRPCProxyController struct {
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
}
|
||||
|
||||
// NewGRPCProxyController creates a new GRPCProxyController.
|
||||
func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
|
||||
return &GRPCProxyController{
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
}
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
|
||||
func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
|
||||
func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
|
||||
return c.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
|
||||
func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster removes a proxy from a cluster.
|
||||
func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||
func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var proxies []string
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxies = append(proxies, key.(string))
|
||||
return true
|
||||
})
|
||||
return proxies
|
||||
}
|
||||
@@ -7,9 +7,8 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -26,26 +25,26 @@ type ClusterDeriver interface {
|
||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
proxyController service.ProxyController
|
||||
clusterDeriver ClusterDeriver
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
||||
return &managerImpl{
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager {
|
||||
return &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -69,34 +68,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
for _, target := range service.Targets {
|
||||
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
|
||||
for _, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
case service.TargetTypePeer:
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = peer.IP.String()
|
||||
case reverseproxy.TargetTypeHost:
|
||||
case service.TargetTypeHost:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Prefix.Addr().String()
|
||||
case reverseproxy.TargetTypeDomain:
|
||||
case service.TargetTypeDomain:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Domain
|
||||
case reverseproxy.TargetTypeSubnet:
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
@@ -105,7 +104,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -126,7 +125,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -135,29 +134,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
if err := m.persistNewService(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
||||
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
@@ -184,7 +183,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
@@ -202,7 +201,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
@@ -218,7 +217,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -242,7 +241,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
@@ -254,7 +253,7 @@ type serviceUpdateInfo struct {
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -292,7 +291,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -309,7 +308,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
@@ -323,40 +322,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
||||
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
case !s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
|
||||
case s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
default:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||
for _, target := range targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
case service.TargetTypePeer:
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||
}
|
||||
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
||||
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
|
||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||
@@ -368,7 +367,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -377,10 +376,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var s *service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -395,9 +394,9 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
return err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -406,7 +405,7 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
|
||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
@@ -424,7 +423,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
|
||||
}
|
||||
|
||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
@@ -441,42 +440,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
@@ -492,7 +491,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
@@ -506,7 +505,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
@@ -522,7 +521,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -20,13 +20,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
Domain: "example.com",
|
||||
Auth: reverseproxy.AuthConfig{},
|
||||
Auth: rpservice.AuthConfig{},
|
||||
}
|
||||
|
||||
err := mgr.initializeServiceForCreate(ctx, accountID, service)
|
||||
@@ -40,12 +40,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("verifies session keys are different", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
|
||||
service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
|
||||
|
||||
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
|
||||
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
|
||||
@@ -87,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
@@ -99,7 +99,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
@@ -110,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
@@ -136,7 +136,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
@@ -166,7 +166,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -179,9 +179,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
@@ -199,7 +199,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -215,10 +215,10 @@ func TestPersistNewService(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
ID: "service-123",
|
||||
Domain: "new.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
Targets: []*rpservice.Target{},
|
||||
}
|
||||
|
||||
// Mock ExecuteInTransaction to execute the function immediately
|
||||
@@ -237,7 +237,7 @@ func TestPersistNewService(t *testing.T) {
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
mgr := &Manager{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -248,10 +248,10 @@ func TestPersistNewService(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
ID: "service-123",
|
||||
Domain: "existing.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
Targets: []*rpservice.Target{},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -260,12 +260,12 @@ func TestPersistNewService(t *testing.T) {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
mgr := &Manager{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
require.Error(t, err)
|
||||
@@ -275,21 +275,21 @@ func TestPersistNewService(t *testing.T) {
|
||||
})
|
||||
}
|
||||
func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
|
||||
t.Run("preserve password when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
existing := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashed-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
updated := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
@@ -302,18 +302,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("preserve pin when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
existing := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PinAuth: &rpservice.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashed-pin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
updated := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PinAuth: &rpservice.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "",
|
||||
},
|
||||
@@ -326,18 +326,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("do not preserve when password is provided", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
existing := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "old-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
updated := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "new-password",
|
||||
},
|
||||
@@ -352,10 +352,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPreserveServiceMetadata(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
|
||||
existing := &reverseproxy.Service{
|
||||
Meta: reverseproxy.ServiceMeta{
|
||||
existing := &rpservice.Service{
|
||||
Meta: rpservice.ServiceMeta{
|
||||
CertificateIssuedAt: time.Now(),
|
||||
Status: "active",
|
||||
},
|
||||
@@ -363,7 +363,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
|
||||
SessionPublicKey: "public-key",
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
updated := &rpservice.Service{
|
||||
Domain: "updated.com",
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -26,15 +26,15 @@ const (
|
||||
Delete Operation = "delete"
|
||||
)
|
||||
|
||||
type ProxyStatus string
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusPending ProxyStatus = "pending"
|
||||
StatusActive ProxyStatus = "active"
|
||||
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
|
||||
StatusCertificatePending ProxyStatus = "certificate_pending"
|
||||
StatusCertificateFailed ProxyStatus = "certificate_failed"
|
||||
StatusError ProxyStatus = "error"
|
||||
StatusPending Status = "pending"
|
||||
StatusActive Status = "active"
|
||||
StatusTunnelNotCreated Status = "tunnel_not_created"
|
||||
StatusCertificatePending Status = "certificate_pending"
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer = "peer"
|
||||
TargetTypeHost = "host"
|
||||
@@ -1,4 +1,4 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate manager: %v", err)
|
||||
log.Fatalf("failed to create certificate service: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
@@ -163,9 +163,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetProxyManager(s.ReverseProxyManager())
|
||||
proxyService.SetProxyManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
})
|
||||
return proxyService
|
||||
})
|
||||
@@ -188,7 +189,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||
|
||||
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create proxy token store: %v", err)
|
||||
}
|
||||
log.Info("One-time token store initialized for proxy authentication")
|
||||
return tokenStore
|
||||
})
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -106,6 +108,12 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ServiceProxyController() service.ProxyController {
|
||||
return Create(s, func() service.ProxyController {
|
||||
return nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
return Create(s, func() *server.AccountRequestBuffer {
|
||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
log.Fatalf("failed to create account service: %v", err)
|
||||
}
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
accountManager.SetServiceManager(s.ReverseProxyManager())
|
||||
accountManager.SetServiceManager(s.ServiceManager())
|
||||
})
|
||||
|
||||
return accountManager
|
||||
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
// Use embedded IdP manager if embedded Dex is configured and enabled.
|
||||
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP manager: %v", err)
|
||||
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||
}
|
||||
return idpManager
|
||||
}
|
||||
|
||||
// Fall back to external IdP manager
|
||||
// Fall back to external IdP service
|
||||
if s.Config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create IDP manager: %v", err)
|
||||
log.Fatalf("failed to create IDP service: %v", err)
|
||||
}
|
||||
}
|
||||
return idpManager
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -190,15 +192,21 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
||||
return Create(s, func() reverseproxy.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
return Create(s, func() proxy.Manager {
|
||||
return proxymanager.NewManager(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache cache.DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &cache.DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authAudience string
|
||||
cliAuthAudience string
|
||||
expectedAudiences []string
|
||||
expectedAudience string
|
||||
}{
|
||||
{
|
||||
name: "only_auth_audience",
|
||||
authAudience: "dashboard-aud",
|
||||
cliAuthAudience: "",
|
||||
expectedAudiences: []string{"dashboard-aud"},
|
||||
expectedAudience: "dashboard-aud",
|
||||
},
|
||||
{
|
||||
name: "both_audiences_different",
|
||||
authAudience: "dashboard-aud",
|
||||
cliAuthAudience: "cli-aud",
|
||||
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
|
||||
expectedAudience: "cli-aud",
|
||||
},
|
||||
{
|
||||
name: "both_audiences_same",
|
||||
authAudience: "same-aud",
|
||||
cliAuthAudience: "same-aud",
|
||||
expectedAudiences: []string{"same-aud"},
|
||||
expectedAudience: "same-aud",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &nbconfig.HttpServerConfig{
|
||||
AuthIssuer: "https://issuer.example.com",
|
||||
AuthAudience: tc.authAudience,
|
||||
CLIAuthAudience: tc.cliAuthAudience,
|
||||
}
|
||||
|
||||
result := buildJWTConfig(config, nil)
|
||||
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
|
||||
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
|
||||
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func testAdvancedCfg() *lfConfig {
|
||||
return &lfConfig{
|
||||
reconnThreshold: 50 * time.Millisecond,
|
||||
baseBlockDuration: 100 * time.Millisecond,
|
||||
reconnLimitForBan: 3,
|
||||
metaChangeLimit: 2,
|
||||
}
|
||||
}
|
||||
|
||||
type LoginFilterTestSuite struct {
|
||||
suite.Suite
|
||||
filter *loginFilter
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) SetupTest() {
|
||||
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
|
||||
}
|
||||
|
||||
func TestLoginFilterTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(LoginFilterTestSuite))
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta))
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
limit := s.filter.cfg.reconnLimitForBan
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
|
||||
s.False(s.filter.allowLogin(pubKey, meta))
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
limit := s.filter.cfg.reconnLimitForBan
|
||||
baseBan := s.filter.cfg.baseBlockDuration
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(1, s.filter.logged[pubKey].banLevel)
|
||||
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
|
||||
|
||||
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
|
||||
s.filter.logged[pubKey].isBanned = false
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(2, s.filter.logged[pubKey].banLevel)
|
||||
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
// nolint
|
||||
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
|
||||
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.filter.logged[pubKey] = &peerState{
|
||||
isBanned: true,
|
||||
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
|
||||
}
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta))
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.False(s.filter.logged[pubKey].isBanned)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.filter.logged[pubKey] = &peerState{
|
||||
currentHash: meta,
|
||||
banLevel: 3,
|
||||
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
|
||||
}
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(0, s.filter.logged[pubKey].banLevel)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
limit := s.filter.cfg.metaChangeLimit
|
||||
|
||||
for i := range limit {
|
||||
s.filter.addLogin(pubKey, uint64(i+1))
|
||||
}
|
||||
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
|
||||
|
||||
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
|
||||
|
||||
s.False(isAllowed, "should block new meta hash after limit is reached")
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta1 := uint64(1)
|
||||
meta2 := uint64(2)
|
||||
meta3 := uint64(3)
|
||||
|
||||
s.filter.addLogin(pubKey, meta1)
|
||||
s.filter.addLogin(pubKey, meta2)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
|
||||
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
|
||||
|
||||
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
|
||||
|
||||
s.filter.addLogin(pubKey, meta3)
|
||||
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
|
||||
}
|
||||
|
||||
func BenchmarkHashingMethods(b *testing.B) {
|
||||
meta := nbpeer.PeerSystemMeta{
|
||||
WtVersion: "1.25.1",
|
||||
OSVersion: "Ubuntu 22.04.3 LTS",
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
|
||||
b.Run("BuilderString", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FnvHashToString", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
_ = resultString
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
|
||||
b.WriteString(meta.WtVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.OSVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.KernelVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
pubKeys := make([]string, numKeys)
|
||||
for i := range numKeys {
|
||||
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for pb.Next() {
|
||||
key := pubKeys[r.Intn(numKeys)]
|
||||
meta := r.Uint64()
|
||||
|
||||
if filter.allowLogin(key, meta) {
|
||||
filter.addLogin(key, meta)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,28 +1,23 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||
// a service is created and must be used exactly once by the proxy
|
||||
// to authenticate a subsequent RPC call.
|
||||
type OneTimeTokenStore struct {
|
||||
tokens map[string]*tokenMetadata
|
||||
mu sync.RWMutex
|
||||
cleanup *time.Ticker
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// tokenMetadata stores information about a one-time token
|
||||
type tokenMetadata struct {
|
||||
ServiceID string
|
||||
AccountID string
|
||||
@@ -30,20 +25,24 @@ type tokenMetadata struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||
// of expired tokens. The cleanupInterval determines how often expired
|
||||
// tokens are removed from memory.
|
||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
store := &OneTimeTokenStore{
|
||||
tokens: make(map[string]*tokenMetadata),
|
||||
cleanup: time.NewTicker(cleanupInterval),
|
||||
cleanupDone: make(chan struct{}),
|
||||
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type OneTimeTokenStore struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a token store with automatic backend selection
|
||||
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
return &OneTimeTokenStore{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateToken creates a new cryptographically secure one-time token
|
||||
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
//
|
||||
// Returns the generated token string or an error if random generation fails.
|
||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.tokens[token] = &tokenMetadata{
|
||||
metadata := &tokenMetadata{
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
|
||||
return "", fmt.Errorf("failed to store token: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||
serviceID, accountID, ttl)
|
||||
|
||||
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
||||
// - Account ID doesn't match
|
||||
// - Reverse proxy ID doesn't match
|
||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
metadata, exists := s.tokens[token]
|
||||
if !exists {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
|
||||
if err != nil {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
metadata := &tokenMetadata{}
|
||||
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
|
||||
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
return fmt.Errorf("invalid token metadata")
|
||||
}
|
||||
|
||||
if time.Now().After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||
metadata.AccountID, accountID)
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
|
||||
return fmt.Errorf("account ID mismatch")
|
||||
}
|
||||
|
||||
// Validate service ID using constant-time comparison
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||
metadata.ServiceID, serviceID)
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
|
||||
return fmt.Errorf("service ID mismatch")
|
||||
}
|
||||
|
||||
// Delete token immediately to enforce single-use
|
||||
delete(s.tokens, token)
|
||||
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
|
||||
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
}
|
||||
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||
serviceID, accountID)
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanup.C:
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for token, metadata := range s.tokens {
|
||||
if now.After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
case <-s.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup goroutine and releases resources
|
||||
func (s *OneTimeTokenStore) Close() {
|
||||
s.cleanup.Stop()
|
||||
close(s.cleanupDone)
|
||||
}
|
||||
|
||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.tokens)
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -24,8 +24,9 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
@@ -58,9 +59,6 @@ type ProxyServiceServer struct {
|
||||
// Map of connected proxies: proxy_id -> proxy connection
|
||||
connectedProxies sync.Map
|
||||
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
|
||||
// Channel for broadcasting reverse proxy updates to all proxies
|
||||
updatesChan chan *proto.ProxyMapping
|
||||
|
||||
@@ -68,7 +66,13 @@ type ProxyServiceServer struct {
|
||||
accessLogManager accesslogs.Manager
|
||||
|
||||
// Manager for reverse proxy operations
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
serviceManager rpservice.Manager
|
||||
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController rpservice.ProxyController
|
||||
|
||||
// Manager for proxy connections
|
||||
proxyManager proxy.Manager
|
||||
|
||||
// Manager for peers
|
||||
peersManager peers.Manager
|
||||
@@ -107,7 +111,7 @@ type proxyConnection struct {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
@@ -116,9 +120,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
tokenStore: tokenStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
pkceCleanupCancel: cancel,
|
||||
}
|
||||
go s.cleanupPKCEVerifiers(ctx)
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -142,13 +148,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.pkceCleanupCancel()
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
|
||||
s.reverseProxyManager = manager
|
||||
func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) {
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
@@ -183,7 +209,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
s.addToCluster(conn.address, proxyID)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
@@ -191,8 +225,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
s.removeFromCluster(conn.address, proxyID)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
@@ -204,6 +245,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
@@ -212,10 +256,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only services matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -224,7 +285,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
|
||||
var filtered []*reverseproxy.Service
|
||||
var filtered []*rpservice.Service
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
@@ -259,7 +320,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
service.ToProtoMapping(
|
||||
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
token,
|
||||
s.GetOIDCValidationConfig(),
|
||||
),
|
||||
@@ -393,61 +454,43 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
||||
return urls
|
||||
}
|
||||
|
||||
// addToCluster registers a proxy in a cluster.
|
||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
|
||||
// removeFromCluster removes a proxy from a cluster.
|
||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
||||
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
|
||||
if clusterAddr == "" {
|
||||
s.SendServiceUpdate(update)
|
||||
return
|
||||
}
|
||||
|
||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
if s.proxyController == nil {
|
||||
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
|
||||
if len(proxyIDs) == 0 {
|
||||
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxyID := key.(string)
|
||||
for _, proxyID := range proxyIDs {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
msg := s.perProxyMessage(update, proxyID)
|
||||
if msg == nil {
|
||||
return true
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
@@ -486,35 +529,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
||||
clusterCounts := make(map[string]int)
|
||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
||||
clusterAddr := key.(string)
|
||||
proxySet := value.(*sync.Map)
|
||||
count := 0
|
||||
proxySet.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
clusterCounts[clusterAddr] = count
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
||||
for addr, count := range clusterCounts {
|
||||
clusters = append(clusters, ClusterInfo{
|
||||
Address: addr,
|
||||
ConnectedProxies: count,
|
||||
})
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||
@@ -533,7 +549,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||
@@ -544,7 +560,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
@@ -558,7 +574,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
|
||||
return true, "pin-user", proxyauth.MethodPIN
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
@@ -580,7 +596,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -620,7 +636,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}
|
||||
|
||||
if certificateIssued {
|
||||
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
||||
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
|
||||
}
|
||||
@@ -632,7 +648,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
@@ -647,22 +663,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}
|
||||
|
||||
// protoStatusToInternal maps proto status to internal status
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
switch protoStatus {
|
||||
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
||||
return reverseproxy.StatusPending
|
||||
return rpservice.StatusPending
|
||||
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
|
||||
return reverseproxy.StatusActive
|
||||
return rpservice.StatusActive
|
||||
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
|
||||
return reverseproxy.StatusTunnelNotCreated
|
||||
return rpservice.StatusTunnelNotCreated
|
||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
|
||||
return reverseproxy.StatusCertificatePending
|
||||
return rpservice.StatusCertificatePending
|
||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
|
||||
return reverseproxy.StatusCertificateFailed
|
||||
return rpservice.StatusCertificateFailed
|
||||
case proto.ProxyStatus_PROXY_STATUS_ERROR:
|
||||
return reverseproxy.StatusError
|
||||
return rpservice.StatusError
|
||||
default:
|
||||
return reverseproxy.StatusError
|
||||
return rpservice.StatusError
|
||||
}
|
||||
}
|
||||
|
||||
@@ -727,7 +743,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
}
|
||||
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
|
||||
@@ -790,8 +806,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC configuration for token validation
|
||||
// in the format needed by ToProtoMapping.
|
||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
|
||||
return reverseproxy.OIDCValidationConfig{
|
||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
|
||||
return rpservice.OIDCValidationConfig{
|
||||
Issuer: s.oidcConfig.Issuer,
|
||||
Audiences: []string{s.oidcConfig.Audience},
|
||||
KeysLocation: s.oidcConfig.KeysLocation,
|
||||
@@ -850,12 +866,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var service *rpservice.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
@@ -921,8 +937,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
|
||||
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
|
||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account services: %w", err)
|
||||
}
|
||||
@@ -1043,8 +1059,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
@@ -1058,7 +1074,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
ip := "192.168.1.1"
|
||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||
l.recordFailure(ip)
|
||||
}
|
||||
|
||||
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||
l.recordFailure("192.168.1.1")
|
||||
}
|
||||
|
||||
assert.True(t, l.isLimited("192.168.1.1"))
|
||||
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
|
||||
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
|
||||
defer l.stop()
|
||||
|
||||
ip := "10.0.0.1"
|
||||
|
||||
// Exhaust burst
|
||||
for i := 0; i < proxyAuthFailureBurst; i++ {
|
||||
l.recordFailure(ip)
|
||||
}
|
||||
require.True(t, l.isLimited(ip))
|
||||
|
||||
// Wait for token replenishment
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
l.recordFailure("10.0.0.1")
|
||||
|
||||
l.mu.Lock()
|
||||
require.Len(t, l.limiters, 1)
|
||||
// Backdate the entry so it looks stale
|
||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||
l.mu.Unlock()
|
||||
|
||||
l.cleanup()
|
||||
|
||||
l.mu.Lock()
|
||||
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
|
||||
l := newAuthFailureLimiter()
|
||||
defer l.stop()
|
||||
|
||||
l.recordFailure("10.0.0.1")
|
||||
l.recordFailure("10.0.0.2")
|
||||
|
||||
l.mu.Lock()
|
||||
// Only backdate one entry
|
||||
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
|
||||
l.mu.Unlock()
|
||||
|
||||
l.cleanup()
|
||||
|
||||
l.mu.Lock()
|
||||
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
|
||||
assert.Contains(t, l.limiters, "10.0.0.2")
|
||||
l.mu.Unlock()
|
||||
}
|
||||
@@ -1,381 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.proxiesByAccount[accountID], nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
return []*reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUsersManager struct {
|
||||
users map[string]*types.User
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
user, ok := m.users[userID]
|
||||
if !ok {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
userID string
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
users map[string]*types.User
|
||||
proxyErr error
|
||||
userErr error
|
||||
expectErr bool
|
||||
expectErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "user not found",
|
||||
domain: "app.example.com",
|
||||
userID: "unknown-user",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||
},
|
||||
users: map[string]*types.User{},
|
||||
expectErr: true,
|
||||
expectErrMsg: "user not found",
|
||||
},
|
||||
{
|
||||
name: "proxy not found in user's account",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: true,
|
||||
expectErrMsg: "service not found",
|
||||
},
|
||||
{
|
||||
name: "proxy exists in different account - not accessible",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: true,
|
||||
expectErrMsg: "service not found",
|
||||
},
|
||||
{
|
||||
name: "no bearer auth configured - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "bearer auth disabled - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
||||
},
|
||||
}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "bearer auth enabled but no groups configured - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "user not in allowed groups",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
|
||||
},
|
||||
expectErr: true,
|
||||
expectErrMsg: "not in allowed groups",
|
||||
},
|
||||
{
|
||||
name: "user in one of the allowed groups - allow access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "user in all allowed groups - allow access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "proxy manager error",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: nil,
|
||||
proxyErr: errors.New("database error"),
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: true,
|
||||
expectErrMsg: "get account services",
|
||||
},
|
||||
{
|
||||
name: "multiple proxies in account - finds correct one",
|
||||
domain: "app2.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {
|
||||
{Domain: "app1.example.com", AccountID: "account1"},
|
||||
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
||||
{Domain: "app3.example.com", AccountID: "account1"},
|
||||
},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &ProxyServiceServer{
|
||||
reverseProxyManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: tt.proxiesByAccount,
|
||||
err: tt.proxyErr,
|
||||
},
|
||||
usersManager: &mockUsersManager{
|
||||
users: tt.users,
|
||||
err: tt.userErr,
|
||||
},
|
||||
}
|
||||
|
||||
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectErrMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
domain string
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
err error
|
||||
expectProxy bool
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "proxy found",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {
|
||||
{Domain: "other.example.com", AccountID: "account1"},
|
||||
{Domain: "app.example.com", AccountID: "account1"},
|
||||
},
|
||||
},
|
||||
expectProxy: true,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "proxy not found in account",
|
||||
accountID: "account1",
|
||||
domain: "unknown.example.com",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||
},
|
||||
expectProxy: false,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty proxy list for account",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||
expectProxy: false,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "manager error",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: nil,
|
||||
err: errors.New("database error"),
|
||||
expectProxy: false,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &ProxyServiceServer{
|
||||
reverseProxyManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: tt.proxiesByAccount,
|
||||
err: tt.err,
|
||||
},
|
||||
}
|
||||
|
||||
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, proxy)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, proxy)
|
||||
assert.Equal(t, tt.domain, proxy.Domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,232 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||
// and returns the channel where messages will be received.
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
||||
ch := make(chan *proto.ProxyMapping, 10)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
sendChan: ch,
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
return msg
|
||||
case <-time.After(time.Second):
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
}
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
const numProxies = 3
|
||||
|
||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
||||
for i := range numProxies {
|
||||
id := "proxy-" + string(rune('a'+i))
|
||||
channels[i] = registerFakeProxy(s, id, cluster)
|
||||
}
|
||||
|
||||
update := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
Path: []*proto.PathMapping{
|
||||
{Path: "/", Target: "http://10.0.0.1:8080/"},
|
||||
},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
|
||||
tokens := make([]string, numProxies)
|
||||
for i, ch := range channels {
|
||||
msg := drainChannel(ch)
|
||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
||||
assert.Equal(t, update.Domain, msg.Domain)
|
||||
assert.Equal(t, update.Id, msg.Id)
|
||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||
tokens[i] = msg.AuthToken
|
||||
}
|
||||
|
||||
// All tokens must be unique
|
||||
tokenSet := make(map[string]struct{})
|
||||
for i, tok := range tokens {
|
||||
_, exists := tokenSet[tok]
|
||||
assert.False(t, exists, "proxy %d got duplicate token", i)
|
||||
tokenSet[tok] = struct{}{}
|
||||
}
|
||||
|
||||
// Each token must be independently consumable
|
||||
for i, tok := range tokens {
|
||||
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
|
||||
assert.NoError(t, err, "proxy %d token should validate successfully", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
}
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||
|
||||
update := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
|
||||
msg1 := drainChannel(ch1)
|
||||
msg2 := drainChannel(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
|
||||
// Delete operations should not generate tokens
|
||||
assert.Empty(t, msg1.AuthToken)
|
||||
assert.Empty(t, msg2.AuthToken)
|
||||
|
||||
// No tokens should have been created
|
||||
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
}
|
||||
|
||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||
|
||||
update := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
s.SendServiceUpdate(update)
|
||||
|
||||
msg1 := drainChannel(ch1)
|
||||
msg2 := drainChannel(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
|
||||
assert.NotEmpty(t, msg1.AuthToken)
|
||||
assert.NotEmpty(t, msg2.AuthToken)
|
||||
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
|
||||
|
||||
// Both tokens should validate
|
||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
||||
}
|
||||
|
||||
// generateState creates a state using the same format as GetOIDCURL.
|
||||
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||
nonce := make([]byte, 16)
|
||||
_, _ = rand.Read(nonce)
|
||||
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
||||
|
||||
payload := redirectURL + "|" + nonceB64
|
||||
hmacSum := s.generateHMAC(payload)
|
||||
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
||||
}
|
||||
|
||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
}
|
||||
|
||||
redirectURL := "https://app.example.com/callback"
|
||||
|
||||
// Generate 100 states for the same redirect URL
|
||||
states := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
state := generateState(s, redirectURL)
|
||||
|
||||
// State must have 3 parts: base64(url)|nonce|hmac
|
||||
parts := strings.Split(state, "|")
|
||||
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
||||
|
||||
// State must be unique
|
||||
require.False(t, states[state], "state %d is a duplicate", i)
|
||||
states[state] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
}
|
||||
|
||||
// Old format had only 2 parts: base64(url)|hmac
|
||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
|
||||
_, _, err := s.ValidateState("base64url|hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
}
|
||||
|
||||
// Store with tampered HMAC
|
||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
|
||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state signature")
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testingClientKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputFlow *config.DeviceAuthorizationFlow
|
||||
expectedFlow *mgmtProto.DeviceAuthorizationFlow
|
||||
expectedErrFunc require.ErrorAssertionFunc
|
||||
expectedErrMSG string
|
||||
expectedComparisonFunc require.ComparisonAssertionFunc
|
||||
expectedComparisonMSG string
|
||||
}{
|
||||
{
|
||||
name: "Testing No Device Flow Config",
|
||||
inputFlow: nil,
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Invalid Device Flow Provider Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "NoNe",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.Error,
|
||||
expectedErrMSG: "should return error",
|
||||
},
|
||||
{
|
||||
name: "Testing Full Device Flow Config",
|
||||
inputFlow: &config.DeviceAuthorizationFlow{
|
||||
Provider: "hosted",
|
||||
ProviderConfig: config.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "test",
|
||||
},
|
||||
},
|
||||
expectedErrFunc: require.NoError,
|
||||
expectedErrMSG: "should not return error",
|
||||
expectedComparisonFunc: require.Equal,
|
||||
expectedComparisonMSG: "should match",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &Server{
|
||||
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
|
||||
config: &config.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
key, err := mgmtServer.secretsManager.GetWGKey()
|
||||
require.NoError(t, err, "should be able to get server key")
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
context.TODO(),
|
||||
&mgmtProto.EncryptedMessage{
|
||||
WgPubKey: testingClientKey.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
},
|
||||
)
|
||||
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,250 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"hash"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var TurnTestHost = &config.Host{
|
||||
Proto: config.UDP,
|
||||
URI: "turn:turn.netbird.io:77777",
|
||||
Username: "username",
|
||||
Password: "",
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
turnCredentials, err := tested.GenerateTurnToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
if turnCredentials.Payload == "" {
|
||||
t.Errorf("expected generated TURN username not to be empty, got empty")
|
||||
}
|
||||
if turnCredentials.Signature == "" {
|
||||
t.Errorf("expected generated TURN password not to be empty, got empty")
|
||||
}
|
||||
|
||||
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
|
||||
|
||||
relayCredentials, err := tested.GenerateRelayToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
if relayCredentials.Payload == "" {
|
||||
t.Errorf("expected generated relay payload not to be empty, got empty")
|
||||
}
|
||||
if relayCredentials.Signature == "" {
|
||||
t.Errorf("expected generated relay signature not to be empty, got empty")
|
||||
}
|
||||
|
||||
hashedSecret := sha256.Sum256([]byte(secret))
|
||||
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: 2 * time.Second}
|
||||
secret := "some_secret"
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
updateChannel := peersManager.CreateChannel(context.Background(), peer)
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tested.SetupRefresh(ctx, "someAccountID", peer)
|
||||
|
||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
|
||||
}
|
||||
|
||||
if _, ok := tested.relayCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
|
||||
}
|
||||
|
||||
var updates []*network_map.UpdateMessage
|
||||
|
||||
loop:
|
||||
for timeout := time.After(5 * time.Second); ; {
|
||||
select {
|
||||
case update := <-updateChannel:
|
||||
updates = append(updates, update)
|
||||
case <-timeout:
|
||||
break loop
|
||||
}
|
||||
|
||||
if len(updates) >= 2 {
|
||||
break loop
|
||||
}
|
||||
}
|
||||
|
||||
if len(updates) < 2 {
|
||||
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
|
||||
}
|
||||
|
||||
var turnUpdates, relayUpdates int
|
||||
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
|
||||
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
|
||||
|
||||
for _, update := range updates {
|
||||
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
|
||||
turnUpdates++
|
||||
if turnUpdates == 1 {
|
||||
firstTurnUpdate = turns[0]
|
||||
} else {
|
||||
secondTurnUpdate = turns[0]
|
||||
}
|
||||
}
|
||||
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
|
||||
// avoid updating on turn updates since they also send relay credentials
|
||||
if update.Update.GetNetbirdConfig().GetTurns() == nil {
|
||||
relayUpdates++
|
||||
if relayUpdates == 1 {
|
||||
firstRelayUpdate = relay
|
||||
} else {
|
||||
secondRelayUpdate = relay
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if turnUpdates < 1 {
|
||||
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
|
||||
}
|
||||
if relayUpdates < 1 {
|
||||
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
|
||||
}
|
||||
|
||||
if firstTurnUpdate != nil && secondTurnUpdate != nil {
|
||||
if firstTurnUpdate.Password == secondTurnUpdate.Password {
|
||||
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
|
||||
}
|
||||
}
|
||||
|
||||
if firstRelayUpdate != nil && secondRelayUpdate != nil {
|
||||
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
|
||||
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
ttl := util.Duration{Duration: time.Hour}
|
||||
secret := "some_secret"
|
||||
peersManager := update_channel.NewPeersUpdateManager(nil)
|
||||
peer := "some_peer"
|
||||
|
||||
rc := &config.Relay{
|
||||
Addresses: []string{"localhost:0"},
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in turn cancel map, got not present")
|
||||
}
|
||||
if _, ok := tested.relayCancelMap[peer]; !ok {
|
||||
t.Errorf("expecting peer to be present in relay cancel map, got not present")
|
||||
}
|
||||
|
||||
tested.CancelRefresh(peer)
|
||||
if _, ok := tested.turnCancelMap[peer]; ok {
|
||||
t.Errorf("expecting peer to be not present in turn cancel map, got present")
|
||||
}
|
||||
if _, ok := tested.relayCancelMap[peer]; ok {
|
||||
t.Errorf("expecting peer to be not present in relay cancel map, got present")
|
||||
}
|
||||
}
|
||||
|
||||
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
|
||||
t.Helper()
|
||||
mac := hmac.New(algo, key)
|
||||
|
||||
_, err := mac.Write([]byte(username))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedMAC := mac.Sum(nil)
|
||||
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
equal := hmac.Equal(decodedMAC, expectedMAC)
|
||||
|
||||
if !equal {
|
||||
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
|
||||
}
|
||||
}
|
||||
@@ -1,587 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
shouldSend := debouncer.ProcessUpdate(update)
|
||||
|
||||
if !shouldSend {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
if debouncer.TimerChannel() == nil {
|
||||
t.Error("Timer should be started after first update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update3 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// First update should be sent immediately
|
||||
if !debouncer.ProcessUpdate(update1) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Rapid subsequent updates should be coalesced
|
||||
if debouncer.ProcessUpdate(update2) {
|
||||
t.Error("Second rapid update should not be sent immediately")
|
||||
}
|
||||
|
||||
if debouncer.ProcessUpdate(update3) {
|
||||
t.Error("Third rapid update should not be sent immediately")
|
||||
}
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != update3 {
|
||||
t.Error("Should get the last update (update3)")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send first update
|
||||
debouncer.ProcessUpdate(update1)
|
||||
|
||||
// Send second update within debounce period
|
||||
debouncer.ProcessUpdate(update2)
|
||||
|
||||
// Wait for timer
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != update2 {
|
||||
t.Error("Should get the last update")
|
||||
}
|
||||
if pendingUpdates[0] == update1 {
|
||||
t.Error("Should not get the first update")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update3 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send first update
|
||||
debouncer.ProcessUpdate(update1)
|
||||
|
||||
// Wait a bit, but not the full debounce period
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
|
||||
// Send second update - should reset timer
|
||||
debouncer.ProcessUpdate(update2)
|
||||
|
||||
// Wait a bit more
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
|
||||
// Send third update - should reset timer again
|
||||
debouncer.ProcessUpdate(update3)
|
||||
|
||||
// Now wait for the timer (should fire after last update's reset)
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != update3 {
|
||||
t.Error("Should get the last update (update3)")
|
||||
}
|
||||
// Timer should be restarted since there was a pending update
|
||||
if debouncer.TimerChannel() == nil {
|
||||
t.Error("Timer should be restarted after sending pending update")
|
||||
}
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update3 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(update1)
|
||||
|
||||
// Second update coalesced
|
||||
debouncer.ProcessUpdate(update2)
|
||||
|
||||
// Wait for timer to expire
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
|
||||
if len(pendingUpdates) == 0 {
|
||||
t.Fatal("Should have pending update")
|
||||
}
|
||||
|
||||
// After sending pending update, timer is restarted, so next update is NOT immediate
|
||||
if debouncer.ProcessUpdate(update3) {
|
||||
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
|
||||
}
|
||||
|
||||
// Wait for the restarted timer and verify update3 is pending
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
finalUpdates := debouncer.GetPendingUpdates()
|
||||
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
|
||||
t.Error("Should get update3 as pending")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired for restarted timer")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send update to start timer
|
||||
debouncer.ProcessUpdate(update)
|
||||
|
||||
// Stop should clean up
|
||||
debouncer.Stop()
|
||||
|
||||
// Multiple stops should be safe
|
||||
debouncer.Stop()
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
// Simulate high-frequency updates
|
||||
var lastUpdate *network_map.UpdateMessage
|
||||
sentImmediately := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: uint64(i),
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
lastUpdate = update
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
sentImmediately++
|
||||
}
|
||||
time.Sleep(1 * time.Millisecond) // Very rapid updates
|
||||
}
|
||||
|
||||
// Only first update should be sent immediately
|
||||
if sentImmediately != 1 {
|
||||
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
|
||||
}
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0] != lastUpdate {
|
||||
t.Error("Should get the very last update")
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
|
||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// Send first update
|
||||
if !debouncer.ProcessUpdate(update) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Wait for timer to expire with no additional updates (true quiet period)
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) != 0 {
|
||||
t.Error("Should have no pending updates")
|
||||
}
|
||||
// After true quiet period, timer should be cleared
|
||||
if debouncer.TimerChannel() != nil {
|
||||
t.Error("Timer should be cleared after quiet period")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
updates := make([]*network_map.UpdateMessage, 5)
|
||||
for i := range updates {
|
||||
updates[i] = &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: uint64(i),
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(updates[0])
|
||||
|
||||
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
|
||||
debouncer.ProcessUpdate(updates[1])
|
||||
debouncer.ProcessUpdate(updates[2])
|
||||
debouncer.ProcessUpdate(updates[3])
|
||||
debouncer.ProcessUpdate(updates[4])
|
||||
|
||||
// Wait for debounce
|
||||
<-debouncer.TimerChannel()
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
|
||||
if len(pendingUpdates) != 1 {
|
||||
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
|
||||
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
update1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
update2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
if !debouncer.ProcessUpdate(update1) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Wait for timer without sending any more updates (true quiet period)
|
||||
<-debouncer.TimerChannel()
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
|
||||
if len(pendingUpdates) != 0 {
|
||||
t.Error("Should have no pending updates during quiet period")
|
||||
}
|
||||
|
||||
// After true quiet period, next update should be sent immediately
|
||||
if !debouncer.ProcessUpdate(update2) {
|
||||
t.Error("Update after true quiet period should be sent immediately")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
// Simulate continuous high-frequency updates
|
||||
for i := 0; i < 10; i++ {
|
||||
update := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: uint64(i),
|
||||
},
|
||||
},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
// First one sent immediately
|
||||
if !debouncer.ProcessUpdate(update) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
} else {
|
||||
// All others should be coalesced (not sent immediately)
|
||||
if debouncer.ProcessUpdate(update) {
|
||||
t.Errorf("Update %d should not be sent immediately", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait a bit but send next update before debounce expires
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Now wait for final debounce
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
if len(pendingUpdates) == 0 {
|
||||
t.Fatal("Should have the last update pending")
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
|
||||
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
netmapUpdate := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
tokenUpdate1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
tokenUpdate2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(netmapUpdate)
|
||||
|
||||
// Send multiple control config updates - they should all be queued
|
||||
debouncer.ProcessUpdate(tokenUpdate1)
|
||||
debouncer.ProcessUpdate(tokenUpdate2)
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
// Should get both control config updates
|
||||
if len(pendingUpdates) != 2 {
|
||||
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
|
||||
}
|
||||
// Control configs should come first
|
||||
if pendingUpdates[0] != tokenUpdate1 {
|
||||
t.Error("First pending update should be tokenUpdate1")
|
||||
}
|
||||
if pendingUpdates[1] != tokenUpdate2 {
|
||||
t.Error("Second pending update should be tokenUpdate2")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
netmapUpdate1 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
netmapUpdate2 := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
tokenUpdate := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
|
||||
// First update sent immediately
|
||||
debouncer.ProcessUpdate(netmapUpdate1)
|
||||
|
||||
// Send token update and network map update
|
||||
debouncer.ProcessUpdate(tokenUpdate)
|
||||
debouncer.ProcessUpdate(netmapUpdate2)
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
// Should get 2 updates in order: token, then network map
|
||||
if len(pendingUpdates) != 2 {
|
||||
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
|
||||
}
|
||||
// Token update should come first (preserves order)
|
||||
if pendingUpdates[0] != tokenUpdate {
|
||||
t.Error("First pending update should be tokenUpdate")
|
||||
}
|
||||
// Network map update should come second
|
||||
if pendingUpdates[1] != netmapUpdate2 {
|
||||
t.Error("Second pending update should be netmapUpdate2")
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
|
||||
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
|
||||
defer debouncer.Stop()
|
||||
|
||||
// Simulate: 50 network maps -> 1 control config -> 50 network maps
|
||||
// Expected result: 3 messages (netmap, controlConfig, netmap)
|
||||
|
||||
// Send first network map immediately
|
||||
firstNetmap := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
if !debouncer.ProcessUpdate(firstNetmap) {
|
||||
t.Error("First update should be sent immediately")
|
||||
}
|
||||
|
||||
// Send 49 more network maps (will be coalesced to last one)
|
||||
var lastNetmapBatch1 *network_map.UpdateMessage
|
||||
for i := 1; i < 50; i++ {
|
||||
lastNetmapBatch1 = &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
debouncer.ProcessUpdate(lastNetmapBatch1)
|
||||
}
|
||||
|
||||
// Send 1 control config
|
||||
controlConfig := &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
|
||||
MessageType: network_map.MessageTypeControlConfig,
|
||||
}
|
||||
debouncer.ProcessUpdate(controlConfig)
|
||||
|
||||
// Send 50 more network maps (will be coalesced to last one)
|
||||
var lastNetmapBatch2 *network_map.UpdateMessage
|
||||
for i := 50; i < 100; i++ {
|
||||
lastNetmapBatch2 = &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
}
|
||||
debouncer.ProcessUpdate(lastNetmapBatch2)
|
||||
}
|
||||
|
||||
// Wait for debounce period
|
||||
select {
|
||||
case <-debouncer.TimerChannel():
|
||||
pendingUpdates := debouncer.GetPendingUpdates()
|
||||
// Should get exactly 3 updates: netmap, controlConfig, netmap
|
||||
if len(pendingUpdates) != 3 {
|
||||
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
|
||||
}
|
||||
// First should be the last netmap from batch 1
|
||||
if pendingUpdates[0] != lastNetmapBatch1 {
|
||||
t.Error("First pending update should be last netmap from batch 1")
|
||||
}
|
||||
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
|
||||
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
|
||||
}
|
||||
// Second should be the control config
|
||||
if pendingUpdates[1] != controlConfig {
|
||||
t.Error("Second pending update should be control config")
|
||||
}
|
||||
// Third should be the last netmap from batch 2
|
||||
if pendingUpdates[2] != lastNetmapBatch2 {
|
||||
t.Error("Third pending update should be last netmap from batch 2")
|
||||
}
|
||||
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
|
||||
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Timer should have fired")
|
||||
}
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
//go:build integration
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type validateSessionTestSetup struct {
|
||||
proxyService *ProxyServiceServer
|
||||
store store.Store
|
||||
cleanup func()
|
||||
}
|
||||
|
||||
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
||||
proxyService.SetProxyManager(proxyManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
return &validateSessionTestSetup{
|
||||
proxyService: proxyService,
|
||||
store: testStore,
|
||||
cleanup: storeCleanup,
|
||||
}
|
||||
}
|
||||
|
||||
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
|
||||
t.Helper()
|
||||
|
||||
pubKey, privKey := generateSessionKeyPair(t)
|
||||
|
||||
testProxy := &reverseproxy.Service{
|
||||
ID: "testProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Test Proxy",
|
||||
Domain: "test-proxy.example.com",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||
|
||||
restrictedProxy := &reverseproxy.Service{
|
||||
ID: "restrictedProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Restricted Proxy",
|
||||
Domain: "restricted-proxy.example.com",
|
||||
Enabled: true,
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"allowedGroupId"},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||
}
|
||||
|
||||
func generateSessionKeyPair(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
|
||||
}
|
||||
|
||||
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
|
||||
t.Helper()
|
||||
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
|
||||
func TestValidateSession_UserAllowed(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "test-proxy.example.com",
|
||||
SessionToken: token,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Valid, "User should be allowed access")
|
||||
assert.Equal(t, "allowedUserId", resp.UserId)
|
||||
assert.Empty(t, resp.DeniedReason)
|
||||
}
|
||||
|
||||
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "restricted-proxy.example.com",
|
||||
SessionToken: token,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid, "User not in group should be denied")
|
||||
assert.Equal(t, "not_in_group", resp.DeniedReason)
|
||||
assert.Equal(t, "nonGroupUserId", resp.UserId)
|
||||
}
|
||||
|
||||
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "test-proxy.example.com",
|
||||
SessionToken: token,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid, "User in different account should be denied")
|
||||
assert.Equal(t, "account_mismatch", resp.DeniedReason)
|
||||
}
|
||||
|
||||
func TestValidateSession_UserNotFound(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "test-proxy.example.com",
|
||||
SessionToken: token,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid, "Non-existent user should be denied")
|
||||
assert.Equal(t, "user_not_found", resp.DeniedReason)
|
||||
}
|
||||
|
||||
func TestValidateSession_ProxyNotFound(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "unknown-proxy.example.com",
|
||||
SessionToken: token,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid, "Unknown proxy should be denied")
|
||||
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
|
||||
}
|
||||
|
||||
func TestValidateSession_InvalidToken(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "test-proxy.example.com",
|
||||
SessionToken: "invalid-token",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid, "Invalid token should be denied")
|
||||
assert.Equal(t, "invalid_token", resp.DeniedReason)
|
||||
}
|
||||
|
||||
func TestValidateSession_MissingDomain(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
SessionToken: "some-token",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid)
|
||||
assert.Contains(t, resp.DeniedReason, "missing")
|
||||
}
|
||||
|
||||
func TestValidateSession_MissingToken(t *testing.T) {
|
||||
setup := setupValidateSessionTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
|
||||
Domain: "test-proxy.example.com",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.Valid)
|
||||
assert.Contains(t, resp.DeniedReason, "missing")
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type testValidateSessionUsersManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
|
||||
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
|
||||
|
||||
requestBuffer *AccountRequestBuffer
|
||||
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
serviceManager service.Manager
|
||||
|
||||
// config contains the management server configuration
|
||||
config *nbconfig.Config
|
||||
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
am.reverseProxyManager = serviceManager
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
am.serviceManager = serviceManager
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -394,7 +394,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
}
|
||||
if reloadReverseProxy {
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -140,5 +140,5 @@ type Manager interface {
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
SetServiceManager(serviceManager reverseproxy.Manager)
|
||||
SetServiceManager(serviceManager service.Manager)
|
||||
}
|
||||
|
||||
@@ -27,8 +27,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
|
||||
@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -73,7 +73,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||
}
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
|
||||
oidcServer := newFakeOIDCServer()
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -91,12 +92,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
}
|
||||
|
||||
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
|
||||
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
|
||||
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
|
||||
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
|
||||
proxyServiceServer.SetProxyManager(reverseProxyManager)
|
||||
am.SetServiceManager(reverseProxyManager)
|
||||
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy token store: %v", err)
|
||||
}
|
||||
proxyMgr := proxymanager.NewManager(store)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
|
||||
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
|
||||
proxyServiceServer.SetProxyManager(serviceManager)
|
||||
am.SetServiceManager(serviceManager)
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
@@ -114,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -148,7 +148,7 @@ type MockAccountManager struct {
|
||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -33,23 +33,23 @@ type Manager interface {
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
groupsManager groups.Manager
|
||||
accountManager account.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
groupsManager groups.Manager
|
||||
accountManager account.Manager
|
||||
serviceManager service.Manager
|
||||
}
|
||||
|
||||
type mockManager struct {
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
reverseProxyManager: reverseproxyManager,
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
serviceManager: reverseproxyManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
|
||||
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
|
||||
go func() {
|
||||
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
|
||||
}
|
||||
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||
require.NoError(t, err)
|
||||
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||
require.Error(t, err)
|
||||
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||
require.Error(t, err)
|
||||
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.NoError(t, err)
|
||||
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.Error(t, err)
|
||||
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.NoError(t, err)
|
||||
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.NoError(t, err)
|
||||
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.NoError(t, err)
|
||||
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
|
||||
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
}
|
||||
|
||||
@@ -28,9 +28,10 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -2063,7 +2064,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
@@ -2078,8 +2079,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
|
||||
var s reverseproxy.Service
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
@@ -2109,7 +2110,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
}
|
||||
|
||||
s.Meta = reverseproxy.ServiceMeta{}
|
||||
s.Meta = rpservice.ServiceMeta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
}
|
||||
@@ -2129,7 +2130,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
s.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
s.Targets = []*reverseproxy.Target{}
|
||||
s.Targets = []*rpservice.Target{}
|
||||
return &s, nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -2141,7 +2142,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
|
||||
serviceIDs := make([]string, len(services))
|
||||
serviceMap := make(map[string]*reverseproxy.Service)
|
||||
serviceMap := make(map[string]*rpservice.Service)
|
||||
for i, s := range services {
|
||||
serviceIDs[i] = s.ID
|
||||
serviceMap[s.ID] = s
|
||||
@@ -2152,8 +2153,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
|
||||
var t reverseproxy.Target
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
|
||||
var t rpservice.Target
|
||||
var path sql.NullString
|
||||
err := row.Scan(
|
||||
&t.ID,
|
||||
@@ -4825,7 +4826,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
@@ -4839,16 +4840,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
}
|
||||
|
||||
// Create target type instance outside transaction to avoid variable shadowing
|
||||
targetType := &rpservice.Target{}
|
||||
|
||||
// Use a transaction to ensure atomic updates of the service and its targets
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete existing targets
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4869,7 +4873,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete service from store")
|
||||
@@ -4882,13 +4886,13 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var service *rpservice.Service
|
||||
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -4906,8 +4910,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
var service *reverseproxy.Service
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
var service *rpservice.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -4925,13 +4929,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -4947,13 +4951,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -5181,13 +5185,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
|
||||
return query
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var target *reverseproxy.Target
|
||||
var target *rpservice.Target
|
||||
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -5200,3 +5204,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// SaveProxy saves or updates a proxy in the database
|
||||
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Update("last_seen", time.Now())
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("last_seen < ?", cutoffTime).
|
||||
Delete(&proxy.Proxy{})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
|
||||
}
|
||||
|
||||
if result.RowsAffected > 0 {
|
||||
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
||||
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
|
||||
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
|
||||
}
|
||||
|
||||
for i := len(models) - 1; i >= 0; i-- {
|
||||
|
||||
@@ -25,9 +25,10 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -252,13 +253,13 @@ type Store interface {
|
||||
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
||||
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
||||
|
||||
CreateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
UpdateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
CreateService(ctx context.Context, service *rpservice.Service) error
|
||||
UpdateService(ctx context.Context, service *rpservice.Service) error
|
||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
||||
|
||||
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||
@@ -270,7 +271,12 @@ type Store interface {
|
||||
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
|
||||
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -12,9 +12,10 @@ import (
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
zones "github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
types "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// CleanupStaleProxies mocks base method.
|
||||
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
|
||||
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -293,7 +308,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
|
||||
}
|
||||
|
||||
// CreateService mocks base method.
|
||||
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateService", ctx, service)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -1095,10 +1110,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
|
||||
}
|
||||
|
||||
// GetAccountServices mocks base method.
|
||||
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1199,6 +1214,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1813,10 +1843,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
|
||||
ret0, _ := ret[0].(*reverseproxy.Service)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1828,10 +1858,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].(*reverseproxy.Service)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1843,10 +1873,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
|
||||
}
|
||||
|
||||
// GetServiceTargetByTargetID mocks base method.
|
||||
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) {
|
||||
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
|
||||
ret0, _ := ret[0].(*reverseproxy.Target)
|
||||
ret0, _ := ret[0].(*service.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1858,10 +1888,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
|
||||
}
|
||||
|
||||
// GetServices mocks base method.
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -2536,6 +2566,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
|
||||
}
|
||||
|
||||
// SaveProxy mocks base method.
|
||||
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveProxy indicates an expected call of SaveProxy.
|
||||
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2731,8 +2775,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateService", ctx, service)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -100,7 +100,7 @@ type Account struct {
|
||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
// Settings is a dictionary of Account settings
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
|
||||
networkResources = append(networkResources, resource.Copy())
|
||||
}
|
||||
|
||||
services := []*reverseproxy.Service{}
|
||||
services := []*service.Service{}
|
||||
for _, service := range a.Services {
|
||||
services = append(services, service.Copy())
|
||||
}
|
||||
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
|
||||
port, ok := a.resolveTargetPort(ctx, target)
|
||||
if !ok {
|
||||
return
|
||||
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
|
||||
if target.Port != 0 {
|
||||
return target.Port, true
|
||||
}
|
||||
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type mockMappingStream struct {
|
||||
grpc.ClientStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
idx int
|
||||
}
|
||||
|
||||
func (m *mockMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
if m.idx >= len(m.messages) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
msg := m.messages[m.idx]
|
||||
m.idx++
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (m *mockMappingStream) Header() (metadata.MD, error) {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
func (m *mockMappingStream) Trailer() metadata.MD { return nil }
|
||||
func (m *mockMappingStream) CloseSend() error { return nil }
|
||||
func (m *mockMappingStream) Context() context.Context { return context.Background() }
|
||||
func (m *mockMappingStream) SendMsg(any) error { return nil }
|
||||
func (m *mockMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{InitialSyncComplete: true},
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
|
||||
checker := health.NewChecker(nil, nil)
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
healthChecker: checker,
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{}, // no sync flag
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, syncDone, "initial sync should not be marked done without flag")
|
||||
}
|
||||
|
||||
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
}
|
||||
|
||||
stream := &mockMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{InitialSyncComplete: true},
|
||||
},
|
||||
}
|
||||
|
||||
syncDone := false
|
||||
err := s.handleMappingStream(context.Background(), stream, &syncDone)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, syncDone, "sync done flag should be set even without health checker")
|
||||
}
|
||||
@@ -1,560 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
proxytypes "github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// integrationTestSetup contains all real components for testing.
|
||||
type integrationTestSetup struct {
|
||||
store store.Store
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
grpcServer *grpc.Server
|
||||
grpcAddr string
|
||||
cleanup func()
|
||||
services []*reverseproxy.Service
|
||||
}
|
||||
|
||||
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create real SQLite store
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test account
|
||||
testAccount := &types.Account{
|
||||
Id: "test-account-1",
|
||||
Domain: "test.com",
|
||||
DomainCategory: "private",
|
||||
IsDomainPrimaryAccount: true,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
|
||||
|
||||
// Generate session keys for reverse proxies
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
// Create test services in the store
|
||||
services := []*reverseproxy.Service{
|
||||
{
|
||||
ID: "rp-1",
|
||||
AccountID: "test-account-1",
|
||||
Name: "Test App 1",
|
||||
Domain: "app1.test.proxy.io",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
TargetId: "peer1",
|
||||
TargetType: "peer",
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
ProxyCluster: "test.proxy.io",
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
},
|
||||
{
|
||||
ID: "rp-2",
|
||||
AccountID: "test-account-1",
|
||||
Name: "Test App 2",
|
||||
Domain: "app2.test.proxy.io",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "10.0.0.2",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
TargetId: "peer2",
|
||||
TargetType: "peer",
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
ProxyCluster: "test.proxy.io",
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
},
|
||||
}
|
||||
|
||||
for _, svc := range services {
|
||||
require.NoError(t, testStore.CreateService(ctx, svc))
|
||||
}
|
||||
|
||||
// Create real token store
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute)
|
||||
|
||||
// Create real users manager
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
// Create real proxy service server with minimal config
|
||||
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||
Issuer: "https://fake-issuer.example.com",
|
||||
ClientID: "test-client",
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
}
|
||||
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
)
|
||||
|
||||
// Use store-backed service manager
|
||||
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
||||
proxyService.SetProxyManager(svcMgr)
|
||||
|
||||
// Start real gRPC server
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
proto.RegisterProxyServiceServer(grpcServer, proxyService)
|
||||
|
||||
go func() {
|
||||
if err := grpcServer.Serve(lis); err != nil {
|
||||
t.Logf("gRPC server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return &integrationTestSetup{
|
||||
store: testStore,
|
||||
proxyService: proxyService,
|
||||
grpcServer: grpcServer,
|
||||
grpcAddr: lis.Addr().String(),
|
||||
services: services,
|
||||
cleanup: func() {
|
||||
grpcServer.GracefulStop()
|
||||
cleanup()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// testAccessLogManager provides access log storage for testing.
|
||||
type testAccessLogManager struct{}
|
||||
|
||||
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) StopPeriodicCleanup() {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
// storeBackedServiceManager reads directly from the real store.
|
||||
type storeBackedServiceManager struct {
|
||||
store store.Store
|
||||
tokenStore *nbgrpc.OneTimeTokenStore
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, accountID string, targetID string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "test-proxy-1",
|
||||
Version: "test-v1",
|
||||
Address: "test.proxy.io",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings from the snapshot - server sends each mapping individually
|
||||
mappingsByID := make(map[string]*proto.ProxyMapping)
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
for _, m := range msg.GetMapping() {
|
||||
mappingsByID[m.GetId()] = m
|
||||
}
|
||||
}
|
||||
|
||||
// Should receive 2 mappings total
|
||||
assert.Len(t, mappingsByID, 2, "Should receive 2 reverse proxy mappings")
|
||||
|
||||
rp1 := mappingsByID["rp-1"]
|
||||
require.NotNil(t, rp1)
|
||||
assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain())
|
||||
assert.Equal(t, "test-account-1", rp1.GetAccountId())
|
||||
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType())
|
||||
assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token for peer creation")
|
||||
|
||||
rp2 := mappingsByID["rp-2"]
|
||||
require.NotNil(t, rp2)
|
||||
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
|
||||
}
|
||||
|
||||
func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: "test-proxy-cluster",
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
mappings := make([]*proto.ProxyMapping, 0)
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
}
|
||||
|
||||
// Should receive the 2 mappings matching the cluster
|
||||
assert.Len(t, mappings, 2, "Should receive mappings for the cluster")
|
||||
|
||||
for _, mapping := range mappings {
|
||||
t.Logf("Received mapping: id=%s domain=%s", mapping.GetId(), mapping.GetDomain())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-reconnect"
|
||||
|
||||
// Helper to receive all mappings from a stream
|
||||
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
|
||||
var mappings []*proto.ProxyMapping
|
||||
for i := 0; i < count; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
mappings = append(mappings, msg.GetMapping()...)
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
// First connection
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
firstMappings := receiveMappings(stream1, 2)
|
||||
cancel1()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second connection (simulating reconnect)
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
secondMappings := receiveMappings(stream2, 2)
|
||||
|
||||
// Should receive the same mappings
|
||||
assert.Equal(t, len(firstMappings), len(secondMappings),
|
||||
"Should receive same number of mappings on reconnect")
|
||||
|
||||
firstIDs := make(map[string]bool)
|
||||
for _, m := range firstMappings {
|
||||
firstIDs[m.GetId()] = true
|
||||
}
|
||||
|
||||
for _, m := range secondMappings {
|
||||
assert.True(t, firstIDs[m.GetId()],
|
||||
"Mapping %s should be present in both connections", m.GetId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
// Use real auth middleware and proxy to verify idempotency
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.WarnLevel)
|
||||
|
||||
authMw := auth.NewMiddleware(logger, nil)
|
||||
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
proxyID := "test-proxy-idempotent"
|
||||
|
||||
var addMappingCalls atomic.Int32
|
||||
|
||||
applyMappings := func(mappings []*proto.ProxyMapping) {
|
||||
for _, mapping := range mappings {
|
||||
if mapping.GetType() == proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
|
||||
addMappingCalls.Add(1)
|
||||
|
||||
// Apply to real auth middleware (idempotent)
|
||||
err := authMw.AddDomain(
|
||||
mapping.GetDomain(),
|
||||
nil,
|
||||
"",
|
||||
0,
|
||||
mapping.GetAccountId(),
|
||||
mapping.GetId(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Apply to real proxy (idempotent)
|
||||
proxyHandler.AddMapping(proxy.Mapping{
|
||||
Host: mapping.GetDomain(),
|
||||
ID: mapping.GetId(),
|
||||
AccountID: proxytypes.AccountID(mapping.GetAccountId()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to receive and apply all mappings
|
||||
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
applyMappings(msg.GetMapping())
|
||||
}
|
||||
}
|
||||
|
||||
// First connection
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
receiveAndApply(stream1)
|
||||
cancel1()
|
||||
|
||||
firstCallCount := addMappingCalls.Load()
|
||||
t.Logf("First connection: applied %d mappings", firstCallCount)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second connection
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
receiveAndApply(stream2)
|
||||
cancel2()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Third connection
|
||||
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel3()
|
||||
|
||||
stream3, err := client.GetMappingUpdate(ctx3, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
receiveAndApply(stream3)
|
||||
|
||||
totalCalls := addMappingCalls.Load()
|
||||
t.Logf("After three connections: total applied %d mappings", totalCalls)
|
||||
|
||||
// Should have called addMapping 6 times (2 mappings x 3 connections)
|
||||
// But internal state is NOT duplicated because auth and proxy use maps keyed by domain/host
|
||||
assert.Equal(t, int32(6), totalCalls, "Should have 6 total calls (2 mappings x 3 connections)")
|
||||
}
|
||||
|
||||
func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) {
|
||||
setup := setupIntegrationTest(t)
|
||||
defer setup.cleanup()
|
||||
|
||||
clusterAddress := "test.proxy.io"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
receivedByProxy := make(map[string]int)
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
wg.Add(1)
|
||||
go func(proxyNum int) {
|
||||
defer wg.Done()
|
||||
|
||||
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewProxyServiceClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
proxyID := "test-proxy-" + string(rune('A'+proxyNum-1))
|
||||
|
||||
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: proxyID,
|
||||
Version: "test-v1",
|
||||
Address: clusterAddress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Receive all mappings - server sends each mapping individually
|
||||
count := 0
|
||||
for i := 0; i < 2; i++ {
|
||||
msg, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
count += len(msg.GetMapping())
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
receivedByProxy[proxyID] = count
|
||||
mu.Unlock()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for proxyID, count := range receivedByProxy {
|
||||
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
|
||||
}
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proxyproto "github.com/pires/go-proxyproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
|
||||
ProxyProtocol: true,
|
||||
}
|
||||
|
||||
raw, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer raw.Close()
|
||||
|
||||
ln := srv.wrapProxyProtocol(raw)
|
||||
|
||||
realClientIP := "203.0.113.50"
|
||||
realClientPort := uint16(54321)
|
||||
|
||||
accepted := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
accepted <- conn
|
||||
}()
|
||||
|
||||
// Connect and send a PROXY v2 header.
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: proxyproto.TCPv4,
|
||||
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
|
||||
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
|
||||
}
|
||||
_, err = header.WriteTo(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case accepted := <-accepted:
|
||||
defer accepted.Close()
|
||||
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
|
||||
}
|
||||
|
||||
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
|
||||
srv := &Server{
|
||||
Logger: log.StandardLogger(),
|
||||
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
}
|
||||
|
||||
opts := proxyproto.ConnPolicyOptions{
|
||||
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||
}
|
||||
policy, err := srv.proxyProtocolPolicy(opts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDebugEndpointDisabledByDefault(t *testing.T) {
|
||||
s := &Server{}
|
||||
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
|
||||
}
|
||||
|
||||
func TestDebugEndpointAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty defaults to localhost",
|
||||
input: "",
|
||||
expected: "localhost:8444",
|
||||
},
|
||||
{
|
||||
name: "explicit localhost preserved",
|
||||
input: "localhost:9999",
|
||||
expected: "localhost:9999",
|
||||
},
|
||||
{
|
||||
name: "explicit address preserved",
|
||||
input: "0.0.0.0:8444",
|
||||
expected: "0.0.0.0:8444",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 preserved",
|
||||
input: "127.0.0.1:8444",
|
||||
expected: "127.0.0.1:8444",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := debugEndpointAddr(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTrustedProxies(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
want []netip.Prefix
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty string returns nil",
|
||||
raw: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single CIDR",
|
||||
raw: "10.0.0.0/8",
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
},
|
||||
{
|
||||
name: "single bare IPv4",
|
||||
raw: "1.2.3.4",
|
||||
want: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/32")},
|
||||
},
|
||||
{
|
||||
name: "single bare IPv6",
|
||||
raw: "::1",
|
||||
want: []netip.Prefix{netip.MustParsePrefix("::1/128")},
|
||||
},
|
||||
{
|
||||
name: "comma-separated CIDRs",
|
||||
raw: "10.0.0.0/8, 192.168.1.0/24",
|
||||
want: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed CIDRs and bare IPs",
|
||||
raw: "10.0.0.0/8, 1.2.3.4, fd00::/8",
|
||||
want: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("1.2.3.4/32"),
|
||||
netip.MustParsePrefix("fd00::/8"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "whitespace around entries",
|
||||
raw: " 10.0.0.0/8 , 192.168.0.0/16 ",
|
||||
want: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trailing comma produces no extra entry",
|
||||
raw: "10.0.0.0/8,",
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
},
|
||||
{
|
||||
name: "invalid entry",
|
||||
raw: "not-an-ip",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "partially invalid",
|
||||
raw: "10.0.0.0/8, garbage",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseTrustedProxies(tt.raw)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user