mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
[management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
This commit is contained in:
@@ -27,21 +27,21 @@ type store interface {
|
||||
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||
}
|
||||
|
||||
type proxyURLProvider interface {
|
||||
GetConnectedProxyURLs() []string
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyURLProvider proxyURLProvider
|
||||
proxyManager proxyManager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
proxyURLProvider: proxyURLProvider,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{
|
||||
Resolver: net.DefaultResolver,
|
||||
},
|
||||
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList := m.proxyURLAllowList()
|
||||
log.WithFields(log.Fields{
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"proxyAllowList": allowList,
|
||||
}).Debug("getting domains with proxy allow list")
|
||||
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList := m.proxyURLAllowList()
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
clusterValid := false
|
||||
for _, cluster := range allowList {
|
||||
if cluster == targetCluster {
|
||||
@@ -221,25 +228,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
||||
}
|
||||
}
|
||||
|
||||
// GetClusterDomains returns a list of proxy cluster domains.
|
||||
func (m Manager) GetClusterDomains() []string {
|
||||
return m.proxyURLAllowList()
|
||||
}
|
||||
|
||||
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||
// their URLs
|
||||
func (m Manager) proxyURLAllowList() []string {
|
||||
var reverseProxyAddresses []string
|
||||
if m.proxyURLProvider != nil {
|
||||
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||
if m.proxyManager == nil {
|
||||
return nil
|
||||
}
|
||||
return reverseProxyAddresses
|
||||
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList := m.proxyURLAllowList()
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
if len(allowList) == 0 {
|
||||
return "", fmt.Errorf("no proxy clusters available")
|
||||
}
|
||||
|
||||
36
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
36
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package proxy
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
type OIDCValidationConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
// Controller is responsible for managing proxy clusters and routing service updates.
|
||||
type Controller interface {
|
||||
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
|
||||
GetOIDCValidationConfig() OIDCValidationConfig
|
||||
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
GetProxiesForCluster(clusterAddr string) []string
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
|
||||
type GRPCController struct {
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
// NewGRPCController creates a new GRPCController.
|
||||
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
|
||||
m, err := newMetrics(meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GRPCController{
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
metrics: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
|
||||
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
|
||||
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
|
||||
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return c.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
|
||||
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
|
||||
c.metrics.IncrementProxyConnectionCount(clusterAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster removes a proxy from a cluster.
|
||||
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
|
||||
c.metrics.DecrementProxyConnectionCount(clusterAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var proxies []string
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxies = append(proxies, key.(string))
|
||||
return true
|
||||
})
|
||||
return proxies
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
type Manager struct {
|
||||
store store
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
// NewManager creates a new proxy Manager
|
||||
func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
m, err := newMetrics(meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
store: store,
|
||||
metrics: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
m.metrics.IncrementProxyHeartbeatCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
type metrics struct {
|
||||
proxyConnectionCount metric.Int64UpDownCounter
|
||||
serviceUpdateSendCount metric.Int64Counter
|
||||
proxyHeartbeatCount metric.Int64Counter
|
||||
}
|
||||
|
||||
func newMetrics(meter metric.Meter) (*metrics, error) {
|
||||
proxyConnectionCount, err := meter.Int64UpDownCounter(
|
||||
"management_proxy_connection_count",
|
||||
metric.WithDescription("Number of active proxy connections"),
|
||||
metric.WithUnit("{connection}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceUpdateSendCount, err := meter.Int64Counter(
|
||||
"management_proxy_service_update_send_count",
|
||||
metric.WithDescription("Total number of service updates sent to proxies"),
|
||||
metric.WithUnit("{update}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyHeartbeatCount, err := meter.Int64Counter(
|
||||
"management_proxy_heartbeat_count",
|
||||
metric.WithDescription("Total number of proxy heartbeats received"),
|
||||
metric.WithUnit("{heartbeat}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &metrics{
|
||||
proxyConnectionCount: proxyConnectionCount,
|
||||
serviceUpdateSendCount: serviceUpdateSendCount,
|
||||
proxyHeartbeatCount: proxyHeartbeatCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
|
||||
m.proxyConnectionCount.Add(context.Background(), 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
|
||||
m.proxyConnectionCount.Add(context.Background(), -1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
|
||||
m.serviceUpdateSendCount.Add(context.Background(), 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementProxyHeartbeatCount() {
|
||||
m.proxyHeartbeatCount.Add(context.Background(), 1)
|
||||
}
|
||||
199
management/internals/modules/reverseproxy/proxy/manager_mock.go
Normal file
199
management/internals/modules/reverseproxy/proxy/manager_mock.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./manager.go
|
||||
|
||||
// Package proxy is a generated GoMock package.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
proto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
type MockManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||
type MockManagerMockRecorder struct {
|
||||
mock *MockManager
|
||||
}
|
||||
|
||||
// NewMockManager creates a new mock instance.
|
||||
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||
mock := &MockManager{ctrl: ctrl}
|
||||
mock.recorder = &MockManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CleanupStale mocks base method.
|
||||
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupStale indicates an expected call of CleanupStale.
|
||||
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disconnect indicates an expected call of Disconnect.
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses mocks base method.
|
||||
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Heartbeat indicates an expected call of Heartbeat.
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockControllerMockRecorder
|
||||
}
|
||||
|
||||
// MockControllerMockRecorder is the mock recorder for MockController.
|
||||
type MockControllerMockRecorder struct {
|
||||
mock *MockController
|
||||
}
|
||||
|
||||
// NewMockController creates a new mock instance.
|
||||
func NewMockController(ctrl *gomock.Controller) *MockController {
|
||||
mock := &MockController{ctrl: ctrl}
|
||||
mock.recorder = &MockControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig mocks base method.
|
||||
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
|
||||
ret0, _ := ret[0].(OIDCValidationConfig)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
|
||||
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
|
||||
}
|
||||
|
||||
// GetProxiesForCluster mocks base method.
|
||||
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
|
||||
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster mocks base method.
|
||||
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
|
||||
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster mocks base method.
|
||||
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
|
||||
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster mocks base method.
|
||||
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
|
||||
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,7 +14,7 @@ type Manager interface {
|
||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||
DeleteAllServices(ctx context.Context, accountID, userID string) error
|
||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||
ReloadService(ctx context.Context, accountID, serviceID string) error
|
||||
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
|
||||
// Package reverseproxy is a generated GoMock package.
|
||||
package reverseproxy
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
@@ -239,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
|
||||
}
|
||||
|
||||
// SetStatus mocks base method.
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -6,10 +6,10 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager reverseproxy.Manager
|
||||
manager rpservice.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all service HTTP endpoints.
|
||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
service := new(reverseproxy.Service)
|
||||
service := new(rpservice.Service)
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
service := new(reverseproxy.Service)
|
||||
service := new(rpservice.Service)
|
||||
service.ID = serviceID
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
@@ -27,7 +27,7 @@ type trackedExpose struct {
|
||||
type exposeTracker struct {
|
||||
activeExposes sync.Map
|
||||
exposeCreateMu sync.Mutex
|
||||
manager *managerImpl
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
func exposeKey(peerID, domain string) string {
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
)
|
||||
|
||||
func TestExposeKey(t *testing.T) {
|
||||
@@ -120,7 +120,7 @@ func TestReapExpiredExposes(t *testing.T) {
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -156,7 +156,7 @@ func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
ctx := context.Background()
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -191,7 +191,7 @@ func TestConcurrentTrackAndCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range 5 {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -11,17 +11,15 @@ import (
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -33,24 +31,22 @@ type ClusterDeriver interface {
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
settingsManager settings.Manager
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
proxyController proxy.Controller
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeTracker *exposeTracker
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
||||
mgr := &managerImpl{
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
settingsManager: settingsManager,
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeTracker = &exposeTracker{manager: mgr}
|
||||
@@ -58,11 +54,11 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
|
||||
}
|
||||
|
||||
// StartExposeReaper delegates to the expose tracker.
|
||||
func (m *managerImpl) StartExposeReaper(ctx context.Context) {
|
||||
func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeTracker.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -86,34 +82,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
for _, target := range service.Targets {
|
||||
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
|
||||
for _, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
case service.TargetTypePeer:
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = peer.IP.String()
|
||||
case reverseproxy.TargetTypeHost:
|
||||
case service.TargetTypeHost:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Prefix.Addr().String()
|
||||
case reverseproxy.TargetTypeDomain:
|
||||
case service.TargetTypeDomain:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Domain
|
||||
case reverseproxy.TargetTypeSubnet:
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
@@ -122,7 +118,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -143,7 +139,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -152,29 +148,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
if err := m.persistNewService(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
||||
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
@@ -201,7 +197,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
@@ -219,7 +215,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
@@ -235,7 +231,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -259,7 +255,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
@@ -271,7 +267,7 @@ type serviceUpdateInfo struct {
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -309,7 +305,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -326,7 +322,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
@@ -340,54 +336,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
||||
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "")
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
case !s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
|
||||
case s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
default:
|
||||
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
|
||||
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings,
|
||||
}
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||
for _, target := range targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
case service.TargetTypePeer:
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||
}
|
||||
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
||||
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
|
||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||
@@ -399,7 +381,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -408,9 +390,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var s *service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
var err error
|
||||
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -429,20 +412,20 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
return err
|
||||
}
|
||||
|
||||
if service.Source == reverseproxy.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
|
||||
if s.Source == service.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain)
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -451,16 +434,16 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*reverseproxy.Service
|
||||
var services []*service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID)
|
||||
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil {
|
||||
for _, svc := range services {
|
||||
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -471,20 +454,14 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
|
||||
return err
|
||||
}
|
||||
|
||||
clusterMappings := make(map[string][]*proto.ProxyMapping)
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, service := range services {
|
||||
if service.Source == reverseproxy.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
|
||||
for _, svc := range services {
|
||||
if svc.Source == service.SourceEphemeral {
|
||||
m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain)
|
||||
}
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
|
||||
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
|
||||
}
|
||||
|
||||
for cluster, mappings := range clusterMappings {
|
||||
m.sendMappingsToCluster(mappings, cluster)
|
||||
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
@@ -494,7 +471,7 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
|
||||
|
||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
@@ -513,7 +490,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
|
||||
}
|
||||
|
||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
@@ -530,50 +507,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
clusterMappings := make(map[string][]*proto.ProxyMapping)
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg)
|
||||
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
|
||||
}
|
||||
|
||||
for cluster, mappings := range clusterMappings {
|
||||
m.sendMappingsToCluster(mappings, cluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
@@ -589,7 +558,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
@@ -603,7 +572,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
@@ -619,7 +588,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
@@ -637,7 +606,7 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri
|
||||
|
||||
// validateExposePermission checks whether the peer is allowed to use the expose feature.
|
||||
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
|
||||
func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error {
|
||||
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
@@ -670,7 +639,7 @@ func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, p
|
||||
// CreateServiceFromPeer creates a service initiated by a peer expose request.
|
||||
// It validates the request, checks expose permissions, enforces the per-peer limit,
|
||||
// creates the service, and tracks it for TTL-based reaping.
|
||||
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
|
||||
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
|
||||
}
|
||||
@@ -679,31 +648,31 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix)
|
||||
serviceName, err := service.GenerateExposeName(req.NamePrefix)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
|
||||
}
|
||||
|
||||
service := req.ToService(accountID, peerID, serviceName)
|
||||
service.Source = reverseproxy.SourceEphemeral
|
||||
svc := req.ToService(accountID, peerID, serviceName)
|
||||
svc.Source = service.SourceEphemeral
|
||||
|
||||
if service.Domain == "" {
|
||||
domain, err := m.buildRandomDomain(service.Name)
|
||||
if svc.Domain == "" {
|
||||
domain, err := m.buildRandomDomain(svc.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build random domain for service %s: %w", service.Name, err)
|
||||
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
|
||||
}
|
||||
service.Domain = domain
|
||||
svc.Domain = domain
|
||||
}
|
||||
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups)
|
||||
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
|
||||
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err)
|
||||
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
|
||||
}
|
||||
service.Auth.BearerAuth.DistributionGroups = groupIDs
|
||||
svc.Auth.BearerAuth.DistributionGroups = groupIDs
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -713,45 +682,45 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
service.Meta.LastRenewedAt = &now
|
||||
service.SourcePeer = peerID
|
||||
svc.Meta.LastRenewedAt = &now
|
||||
svc.SourcePeer = peerID
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
if err := m.persistNewService(ctx, accountID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID)
|
||||
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID)
|
||||
if alreadyTracked {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err)
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
if !allowed {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err)
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta)
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
return nil, fmt.Errorf("replace host by lookup for service %s: %w", service.ID, err)
|
||||
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
|
||||
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return &reverseproxy.ExposeServiceResponse{
|
||||
ServiceName: service.Name,
|
||||
ServiceURL: "https://" + service.Domain,
|
||||
Domain: service.Domain,
|
||||
return &service.ExposeServiceResponse{
|
||||
ServiceName: svc.Name,
|
||||
ServiceURL: "https://" + svc.Domain,
|
||||
Domain: svc.Domain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
|
||||
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
|
||||
if len(groupNames) == 0 {
|
||||
return []string{}, fmt.Errorf("no group names provided")
|
||||
}
|
||||
@@ -766,7 +735,7 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) buildRandomDomain(name string) (string, error) {
|
||||
func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||
if m.clusterDeriver == nil {
|
||||
return "", fmt.Errorf("unable to get random domain")
|
||||
}
|
||||
@@ -781,7 +750,7 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) {
|
||||
|
||||
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
|
||||
// Returns an error if the expose is not actively tracked.
|
||||
func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
|
||||
func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
|
||||
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
|
||||
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
|
||||
}
|
||||
@@ -789,7 +758,7 @@ func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain
|
||||
}
|
||||
|
||||
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
|
||||
func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
|
||||
return err
|
||||
@@ -804,8 +773,8 @@ func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID
|
||||
|
||||
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
|
||||
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
|
||||
func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
|
||||
service, err := m.lookupPeerService(ctx, accountID, peerID, domain)
|
||||
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
|
||||
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -814,41 +783,41 @@ func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peer
|
||||
if expired {
|
||||
activityCode = activity.PeerServiceExposeExpired
|
||||
}
|
||||
return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode)
|
||||
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
|
||||
}
|
||||
|
||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||
func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) {
|
||||
service, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if service.Source != reverseproxy.SourceEphemeral {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
|
||||
}
|
||||
|
||||
if service.SourcePeer != peerID {
|
||||
if svc.SourcePeer != peerID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
|
||||
}
|
||||
|
||||
return service, nil
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
|
||||
var service *reverseproxy.Service
|
||||
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
|
||||
var svc *service.Service
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if service.Source != reverseproxy.SourceEphemeral {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
|
||||
}
|
||||
|
||||
if service.SourcePeer != peerID {
|
||||
if svc.SourcePeer != peerID {
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
|
||||
}
|
||||
|
||||
@@ -868,11 +837,11 @@ func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID,
|
||||
peer = nil
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -10,21 +10,21 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -33,13 +33,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
Domain: "example.com",
|
||||
Auth: reverseproxy.AuthConfig{},
|
||||
Auth: rpservice.AuthConfig{},
|
||||
}
|
||||
|
||||
err := mgr.initializeServiceForCreate(ctx, accountID, service)
|
||||
@@ -53,12 +53,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("verifies session keys are different", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
|
||||
service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
|
||||
|
||||
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
|
||||
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
|
||||
@@ -100,7 +100,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
@@ -112,7 +112,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
@@ -123,7 +123,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
@@ -149,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
@@ -179,7 +179,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -192,9 +192,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
@@ -212,7 +212,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -228,10 +228,10 @@ func TestPersistNewService(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
ID: "service-123",
|
||||
Domain: "new.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
Targets: []*rpservice.Target{},
|
||||
}
|
||||
|
||||
// Mock ExecuteInTransaction to execute the function immediately
|
||||
@@ -250,7 +250,7 @@ func TestPersistNewService(t *testing.T) {
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
mgr := &Manager{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
@@ -261,10 +261,10 @@ func TestPersistNewService(t *testing.T) {
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
ID: "service-123",
|
||||
Domain: "existing.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
Targets: []*rpservice.Target{},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -273,12 +273,12 @@ func TestPersistNewService(t *testing.T) {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
mgr := &Manager{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
require.Error(t, err)
|
||||
@@ -288,21 +288,21 @@ func TestPersistNewService(t *testing.T) {
|
||||
})
|
||||
}
|
||||
func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
|
||||
t.Run("preserve password when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
existing := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashed-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
updated := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
@@ -315,18 +315,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("preserve pin when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
existing := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PinAuth: &rpservice.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashed-pin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
updated := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PinAuth: &rpservice.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "",
|
||||
},
|
||||
@@ -339,18 +339,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("do not preserve when password is provided", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
existing := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "old-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
updated := &rpservice.Service{
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "new-password",
|
||||
},
|
||||
@@ -365,10 +365,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPreserveServiceMetadata(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
mgr := &Manager{}
|
||||
|
||||
existing := &reverseproxy.Service{
|
||||
Meta: reverseproxy.ServiceMeta{
|
||||
existing := &rpservice.Service{
|
||||
Meta: rpservice.Meta{
|
||||
CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(),
|
||||
Status: "active",
|
||||
},
|
||||
@@ -376,7 +376,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
|
||||
SessionPublicKey: "public-key",
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
updated := &rpservice.Service{
|
||||
Domain: "updated.com",
|
||||
}
|
||||
|
||||
@@ -400,31 +400,32 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
}
|
||||
|
||||
newEphemeralService := func() *reverseproxy.Service {
|
||||
return &reverseproxy.Service{
|
||||
newEphemeralService := func() *rpservice.Service {
|
||||
return &rpservice.Service{
|
||||
ID: serviceID,
|
||||
AccountID: accountID,
|
||||
Name: "test-service",
|
||||
Domain: "test.example.com",
|
||||
Source: reverseproxy.SourceEphemeral,
|
||||
Source: rpservice.SourceEphemeral,
|
||||
SourcePeer: ownerPeerID,
|
||||
}
|
||||
}
|
||||
|
||||
newPermanentService := func() *reverseproxy.Service {
|
||||
return &reverseproxy.Service{
|
||||
newPermanentService := func() *rpservice.Service {
|
||||
return &rpservice.Service{
|
||||
ID: serviceID,
|
||||
AccountID: accountID,
|
||||
Name: "api-service",
|
||||
Domain: "api.example.com",
|
||||
Source: reverseproxy.SourcePermanent,
|
||||
Source: rpservice.SourcePermanent,
|
||||
}
|
||||
}
|
||||
|
||||
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
||||
@@ -458,10 +459,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
|
||||
Return(testPeer, nil)
|
||||
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
store: mockStore,
|
||||
accountManager: mockAccountMgr,
|
||||
proxyGRPCServer: newProxyServer(t),
|
||||
proxyController: func() proxy.Controller {
|
||||
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
return c
|
||||
}(),
|
||||
}
|
||||
|
||||
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
|
||||
@@ -485,7 +490,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
@@ -514,7 +519,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
store: mockStore,
|
||||
}
|
||||
|
||||
@@ -556,10 +561,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
|
||||
Return(testPeer, nil)
|
||||
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
store: mockStore,
|
||||
accountManager: mockAccountMgr,
|
||||
proxyGRPCServer: newProxyServer(t),
|
||||
proxyController: func() proxy.Controller {
|
||||
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
return c
|
||||
}(),
|
||||
}
|
||||
|
||||
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired)
|
||||
@@ -596,10 +605,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
|
||||
Return(testPeer, nil)
|
||||
|
||||
mgr := &managerImpl{
|
||||
mgr := &Manager{
|
||||
store: mockStore,
|
||||
accountManager: mockAccountMgr,
|
||||
proxyGRPCServer: newProxyServer(t),
|
||||
proxyController: func() proxy.Controller {
|
||||
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
return c
|
||||
}(),
|
||||
}
|
||||
|
||||
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
|
||||
@@ -612,19 +625,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// noopExtraSettings is a minimal extra_settings.Manager for tests without external integrations.
|
||||
type noopExtraSettings struct{}
|
||||
|
||||
func (n *noopExtraSettings) GetExtraSettings(_ context.Context, _ string) (*types.ExtraSettings, error) {
|
||||
return &types.ExtraSettings{}, nil
|
||||
}
|
||||
|
||||
func (n *noopExtraSettings) UpdateExtraSettings(_ context.Context, _, _ string, _ *types.ExtraSettings) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var _ extra_settings.Manager = (*noopExtraSettings)(nil)
|
||||
|
||||
// testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list.
|
||||
type testClusterDeriver struct {
|
||||
domains []string
|
||||
@@ -646,7 +646,7 @@ const (
|
||||
)
|
||||
|
||||
// setupIntegrationTest creates a real SQLite store with seeded test data for integration tests.
|
||||
func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
|
||||
func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -694,30 +694,28 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
|
||||
require.NoError(t, err)
|
||||
|
||||
permsMgr := permissions.NewManager(testStore)
|
||||
usersMgr := users.NewManager(testStore)
|
||||
settingsMgr := settings.NewManager(testStore, usersMgr, &noopExtraSettings{}, permsMgr, settings.IdpConfig{})
|
||||
|
||||
var storedEvents []activity.Activity
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||
storedEvents = append(storedEvents, activityID.(activity.Activity))
|
||||
},
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
},
|
||||
}
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(proxySrv.Close)
|
||||
|
||||
mgr := &managerImpl{
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permsMgr,
|
||||
settingsManager: settingsMgr,
|
||||
proxyGRPCServer: proxySrv,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: &testClusterDeriver{
|
||||
domains: []string{"test.netbird.io"},
|
||||
},
|
||||
@@ -791,7 +789,7 @@ func Test_validateExposePermission(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
mgr := &Manager{store: mockStore}
|
||||
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "get account settings")
|
||||
@@ -804,7 +802,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
t.Run("creates service with random domain", func(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
@@ -819,7 +817,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.Domain, persisted.Domain)
|
||||
assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral")
|
||||
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
|
||||
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
|
||||
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
|
||||
})
|
||||
@@ -827,7 +825,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
t.Run("creates service with custom domain", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 80,
|
||||
Protocol: "http",
|
||||
Domain: "example.com",
|
||||
@@ -848,7 +846,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
@@ -861,7 +859,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
t.Run("validates request fields", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 0,
|
||||
Protocol: "http",
|
||||
}
|
||||
@@ -875,67 +873,67 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
func TestExposeServiceRequestValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req reverseproxy.ExposeServiceRequest
|
||||
req rpservice.ExposeServiceRequest
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid http request",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid https request with pin",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "port zero rejected",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
|
||||
wantErr: "port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "negative port rejected",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"},
|
||||
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
|
||||
wantErr: "port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "port above 65535 rejected",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
|
||||
wantErr: "port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "unsupported protocol",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
|
||||
wantErr: "unsupported protocol",
|
||||
},
|
||||
{
|
||||
name: "invalid pin format",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
|
||||
wantErr: "invalid pin",
|
||||
},
|
||||
{
|
||||
name: "pin too short",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
|
||||
wantErr: "invalid pin",
|
||||
},
|
||||
{
|
||||
name: "valid 6-digit pin",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "empty user group name",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
|
||||
wantErr: "user group name cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "invalid name prefix",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
|
||||
wantErr: "invalid name prefix",
|
||||
},
|
||||
{
|
||||
name: "valid name prefix",
|
||||
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
@@ -953,7 +951,7 @@ func TestExposeServiceRequestValidate(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("nil receiver", func(t *testing.T) {
|
||||
var req *reverseproxy.ExposeServiceRequest
|
||||
var req *rpservice.ExposeServiceRequest
|
||||
err := req.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "request cannot be nil")
|
||||
@@ -967,7 +965,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
// First create a service
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
@@ -986,7 +984,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
||||
t.Run("expire uses correct activity", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
@@ -1004,7 +1002,7 @@ func TestStopServiceFromPeer(t *testing.T) {
|
||||
t.Run("stops service by domain", func(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
req := &reverseproxy.ExposeServiceRequest{
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
@@ -1023,7 +1021,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -1041,7 +1039,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
|
||||
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
|
||||
|
||||
// A new expose should succeed (not blocked by stale tracking)
|
||||
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9090,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -1053,7 +1051,7 @@ func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
for i := range 3 {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -1074,7 +1072,7 @@ func TestRenewServiceFromPeer(t *testing.T) {
|
||||
t.Run("renews tracked expose", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
@@ -1129,25 +1127,32 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
mockPerms := permissions.NewMockManager(ctrl)
|
||||
mockAcct := account.NewMockManager(ctrl)
|
||||
mockGRPC := &nbgrpc.ProxyServiceServer{}
|
||||
|
||||
mgr := &managerImpl{
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(proxySrv.Close)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr := &Manager{
|
||||
store: sqlStore,
|
||||
permissionsManager: mockPerms,
|
||||
accountManager: mockAcct,
|
||||
proxyGRPCServer: mockGRPC,
|
||||
proxyController: proxyController,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
service := &rpservice.Service{
|
||||
ID: "service-1",
|
||||
AccountID: accountID,
|
||||
Domain: "test.example.com",
|
||||
ProxyCluster: "cluster1",
|
||||
Enabled: true,
|
||||
Targets: []*reverseproxy.Target{
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-1"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-2"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-3"},
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-1"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-2"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-3"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
@@ -29,15 +30,15 @@ const (
|
||||
Delete Operation = "delete"
|
||||
)
|
||||
|
||||
type ProxyStatus string
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusPending ProxyStatus = "pending"
|
||||
StatusActive ProxyStatus = "active"
|
||||
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
|
||||
StatusCertificatePending ProxyStatus = "certificate_pending"
|
||||
StatusCertificateFailed ProxyStatus = "certificate_failed"
|
||||
StatusError ProxyStatus = "error"
|
||||
StatusPending Status = "pending"
|
||||
StatusActive Status = "active"
|
||||
StatusTunnelNotCreated Status = "tunnel_not_created"
|
||||
StatusCertificatePending Status = "certificate_pending"
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer = "peer"
|
||||
TargetTypeHost = "host"
|
||||
@@ -111,14 +112,7 @@ func (a *AuthConfig) ClearSecrets() {
|
||||
}
|
||||
}
|
||||
|
||||
type OIDCValidationConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
type ServiceMeta struct {
|
||||
type Meta struct {
|
||||
CreatedAt time.Time
|
||||
CertificateIssuedAt *time.Time
|
||||
Status string
|
||||
@@ -136,7 +130,7 @@ type Service struct {
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent'"`
|
||||
@@ -165,7 +159,7 @@ func NewService(accountID, name, domain, proxyCluster string, targets []*Target,
|
||||
// only be called during initial creation, not for updates.
|
||||
func (s *Service) InitNewRecord() {
|
||||
s.ID = xid.New().String()
|
||||
s.Meta = ServiceMeta{
|
||||
s.Meta = Meta{
|
||||
CreatedAt: time.Now(),
|
||||
Status: string(StatusPending),
|
||||
}
|
||||
@@ -239,7 +233,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
if !target.Enabled {
|
||||
@@ -1,4 +1,4 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -109,7 +110,7 @@ func TestIsDefaultPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
oidcConfig := OIDCValidationConfig{}
|
||||
oidcConfig := proxy.OIDCValidationConfig{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -202,7 +203,7 @@ func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||
}
|
||||
@@ -219,7 +220,7 @@ func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.op), func(t *testing.T) {
|
||||
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
||||
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
|
||||
assert.Equal(t, tt.want, pm.Type)
|
||||
})
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate manager: %v", err)
|
||||
log.Fatalf("failed to create certificate service: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
@@ -152,10 +152,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
reverseProxyMgr := s.ReverseProxyManager()
|
||||
srv.SetReverseProxyManager(reverseProxyMgr)
|
||||
if reverseProxyMgr != nil {
|
||||
reverseProxyMgr.StartExposeReaper(context.Background())
|
||||
serviceMgr := s.ServiceManager()
|
||||
srv.SetReverseProxyManager(serviceMgr)
|
||||
if serviceMgr != nil {
|
||||
serviceMgr.StartExposeReaper(context.Background())
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
@@ -168,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetProxyManager(s.ReverseProxyManager())
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
})
|
||||
return proxyService
|
||||
})
|
||||
@@ -193,7 +194,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
|
||||
|
||||
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
return Create(s, func() *nbgrpc.OneTimeTokenStore {
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create proxy token store: %v", err)
|
||||
}
|
||||
log.Info("One-time token store initialized for proxy authentication")
|
||||
return tokenStore
|
||||
})
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -106,6 +108,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ServiceProxyController() proxy.Controller {
|
||||
return Create(s, func() proxy.Controller {
|
||||
controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create service proxy controller: %v", err)
|
||||
}
|
||||
return controller
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
return Create(s, func() *server.AccountRequestBuffer {
|
||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
log.Fatalf("failed to create account service: %v", err)
|
||||
}
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
accountManager.SetServiceManager(s.ReverseProxyManager())
|
||||
accountManager.SetServiceManager(s.ServiceManager())
|
||||
})
|
||||
|
||||
return accountManager
|
||||
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
// Use embedded IdP manager if embedded Dex is configured and enabled.
|
||||
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP manager: %v", err)
|
||||
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||
}
|
||||
return idpManager
|
||||
}
|
||||
|
||||
// Fall back to external IdP manager
|
||||
// Fall back to external IdP service
|
||||
if s.Config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create IDP manager: %v", err)
|
||||
log.Fatalf("failed to create IDP service: %v", err)
|
||||
}
|
||||
}
|
||||
return idpManager
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -190,15 +192,25 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
|
||||
return Create(s, func() reverseproxy.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.SettingsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
return Create(s, func() proxy.Manager {
|
||||
manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create proxy manager: %v", err)
|
||||
}
|
||||
return manager
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
|
||||
// Eagerly create the gRPC server so that all AfterInit hooks are registered
|
||||
// before we iterate them. Lazy creation after the loop would miss hooks
|
||||
// registered during GRPCServer() construction (e.g., SetProxyManager).
|
||||
// registered during GRPCServer() construction (e.g., SetServiceManager).
|
||||
s.GRPCServer()
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -39,7 +39,7 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage)
|
||||
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
|
||||
}
|
||||
|
||||
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &reverseproxy.ExposeServiceRequest{
|
||||
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
|
||||
NamePrefix: exposeReq.NamePrefix,
|
||||
Port: int(exposeReq.Port),
|
||||
Protocol: exposeProtocolToString(exposeReq.Protocol),
|
||||
@@ -167,14 +167,14 @@ func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key
|
||||
return accountID, peer, nil
|
||||
}
|
||||
|
||||
func (s *Server) getReverseProxyManager() reverseproxy.Manager {
|
||||
func (s *Server) getReverseProxyManager() rpservice.Manager {
|
||||
s.reverseProxyMu.RLock()
|
||||
defer s.reverseProxyMu.RUnlock()
|
||||
return s.reverseProxyManager
|
||||
}
|
||||
|
||||
// SetReverseProxyManager sets the reverse proxy manager on the server.
|
||||
func (s *Server) SetReverseProxyManager(mgr reverseproxy.Manager) {
|
||||
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
|
||||
s.reverseProxyMu.Lock()
|
||||
defer s.reverseProxyMu.Unlock()
|
||||
s.reverseProxyManager = mgr
|
||||
|
||||
@@ -1,28 +1,23 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// OneTimeTokenStore manages short-lived, single-use authentication tokens
|
||||
// for proxy-to-management RPC authentication. Tokens are generated when
|
||||
// a service is created and must be used exactly once by the proxy
|
||||
// to authenticate a subsequent RPC call.
|
||||
type OneTimeTokenStore struct {
|
||||
tokens map[string]*tokenMetadata
|
||||
mu sync.RWMutex
|
||||
cleanup *time.Ticker
|
||||
cleanupDone chan struct{}
|
||||
}
|
||||
|
||||
// tokenMetadata stores information about a one-time token
|
||||
type tokenMetadata struct {
|
||||
ServiceID string
|
||||
AccountID string
|
||||
@@ -30,20 +25,24 @@ type tokenMetadata struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a new token store with automatic cleanup
|
||||
// of expired tokens. The cleanupInterval determines how often expired
|
||||
// tokens are removed from memory.
|
||||
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
store := &OneTimeTokenStore{
|
||||
tokens: make(map[string]*tokenMetadata),
|
||||
cleanup: time.NewTicker(cleanupInterval),
|
||||
cleanupDone: make(chan struct{}),
|
||||
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type OneTimeTokenStore struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewOneTimeTokenStore creates a token store with automatic backend selection
|
||||
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
}
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go store.cleanupExpired()
|
||||
|
||||
return store
|
||||
return &OneTimeTokenStore{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateToken creates a new cryptographically secure one-time token
|
||||
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
|
||||
//
|
||||
// Returns the generated token string or an error if random generation fails.
|
||||
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
|
||||
// Generate 32 bytes (256 bits) of cryptographically secure random data
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random token: %w", err)
|
||||
}
|
||||
|
||||
// Encode as URL-safe base64 for easy transmission in gRPC
|
||||
token := base64.URLEncoding.EncodeToString(randomBytes)
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.tokens[token] = &tokenMetadata{
|
||||
metadata := &tokenMetadata{
|
||||
ServiceID: serviceID,
|
||||
AccountID: accountID,
|
||||
ExpiresAt: time.Now().Add(ttl),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
|
||||
}
|
||||
|
||||
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
|
||||
return "", fmt.Errorf("failed to store token: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
|
||||
serviceID, accountID, ttl)
|
||||
|
||||
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
||||
// - Account ID doesn't match
|
||||
// - Reverse proxy ID doesn't match
|
||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
hashedToken := hashToken(token)
|
||||
|
||||
metadata, exists := s.tokens[token]
|
||||
if !exists {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
|
||||
if err != nil {
|
||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
metadata := &tokenMetadata{}
|
||||
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
|
||||
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
return fmt.Errorf("invalid token metadata")
|
||||
}
|
||||
|
||||
if time.Now().After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
|
||||
serviceID, accountID)
|
||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
// Validate account ID using constant-time comparison (prevents timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
|
||||
metadata.AccountID, accountID)
|
||||
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
|
||||
return fmt.Errorf("account ID mismatch")
|
||||
}
|
||||
|
||||
// Validate service ID using constant-time comparison
|
||||
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
|
||||
metadata.ServiceID, serviceID)
|
||||
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
|
||||
return fmt.Errorf("service ID mismatch")
|
||||
}
|
||||
|
||||
// Delete token immediately to enforce single-use
|
||||
delete(s.tokens, token)
|
||||
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
|
||||
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||
}
|
||||
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s",
|
||||
serviceID, accountID)
|
||||
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupExpired removes expired tokens in the background to prevent memory leaks
|
||||
func (s *OneTimeTokenStore) cleanupExpired() {
|
||||
for {
|
||||
select {
|
||||
case <-s.cleanup.C:
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
removed := 0
|
||||
for token, metadata := range s.tokens {
|
||||
if now.After(metadata.ExpiresAt) {
|
||||
delete(s.tokens, token)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if removed > 0 {
|
||||
log.Debugf("Cleaned up %d expired one-time tokens", removed)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
case <-s.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup goroutine and releases resources
|
||||
func (s *OneTimeTokenStore) Close() {
|
||||
s.cleanup.Stop()
|
||||
close(s.cleanupDone)
|
||||
}
|
||||
|
||||
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
|
||||
func (s *OneTimeTokenStore) GetTokenCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.tokens)
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -24,8 +24,9 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
@@ -58,14 +59,17 @@ type ProxyServiceServer struct {
|
||||
// Map of connected proxies: proxy_id -> proxy connection
|
||||
connectedProxies sync.Map
|
||||
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
|
||||
// Manager for access logs
|
||||
accessLogManager accesslogs.Manager
|
||||
|
||||
// Manager for reverse proxy operations
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
serviceManager rpservice.Manager
|
||||
|
||||
// ProxyController for service updates and cluster management
|
||||
proxyController proxy.Controller
|
||||
|
||||
// Manager for proxy connections
|
||||
proxyManager proxy.Manager
|
||||
|
||||
// Manager for peers
|
||||
peersManager peers.Manager
|
||||
@@ -104,7 +108,7 @@ type proxyConnection struct {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -112,9 +116,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
tokenStore: tokenStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
pkceCleanupCancel: cancel,
|
||||
}
|
||||
go s.cleanupPKCEVerifiers(ctx)
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -138,13 +144,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.pkceCleanupCancel()
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
|
||||
s.reverseProxyManager = manager
|
||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) {
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
@@ -179,7 +205,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
s.addToCluster(conn.address, proxyID)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
@@ -187,8 +221,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
s.removeFromCluster(conn.address, proxyID)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
@@ -200,6 +241,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
@@ -208,10 +252,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
||||
// Only services matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get services from store: %w", err)
|
||||
}
|
||||
@@ -220,7 +281,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
|
||||
var filtered []*reverseproxy.Service
|
||||
var filtered []*rpservice.Service
|
||||
for _, service := range services {
|
||||
if !service.Enabled {
|
||||
continue
|
||||
@@ -255,7 +316,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
service.ToProtoMapping(
|
||||
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
token,
|
||||
s.GetOIDCValidationConfig(),
|
||||
),
|
||||
@@ -389,61 +450,47 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
||||
return urls
|
||||
}
|
||||
|
||||
// addToCluster registers a proxy in a cluster.
|
||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
|
||||
// removeFromCluster removes a proxy from a cluster.
|
||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
||||
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.GetMappingUpdateResponse, clusterAddr string) {
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
|
||||
updateResponse := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{update},
|
||||
}
|
||||
|
||||
if clusterAddr == "" {
|
||||
s.SendServiceUpdate(update)
|
||||
s.SendServiceUpdate(updateResponse)
|
||||
return
|
||||
}
|
||||
|
||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
if s.proxyController == nil {
|
||||
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
|
||||
if len(proxyIDs) == 0 {
|
||||
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxyID := key.(string)
|
||||
for _, proxyID := range proxyIDs {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
msg := s.perProxyMessage(update, proxyID)
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
return true
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.Debugf("Sent service update to proxy %s in cluster %s", proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
@@ -490,35 +537,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
||||
clusterCounts := make(map[string]int)
|
||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
||||
clusterAddr := key.(string)
|
||||
proxySet := value.(*sync.Map)
|
||||
count := 0
|
||||
proxySet.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
clusterCounts[clusterAddr] = count
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
||||
for addr, count := range clusterCounts {
|
||||
clusters = append(clusters, ClusterInfo{
|
||||
Address: addr,
|
||||
ConnectedProxies: count,
|
||||
})
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
|
||||
@@ -537,7 +557,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
|
||||
switch v := req.GetRequest().(type) {
|
||||
case *proto.AuthenticateRequest_Pin:
|
||||
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
|
||||
@@ -548,7 +568,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
@@ -562,7 +582,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
|
||||
return true, "pin-user", proxyauth.MethodPIN
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
|
||||
if auth == nil || !auth.Enabled {
|
||||
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
|
||||
return false, "", ""
|
||||
@@ -584,7 +604,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
|
||||
if !authenticated || service.SessionPrivateKey == "" {
|
||||
return "", nil
|
||||
}
|
||||
@@ -624,7 +644,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}
|
||||
|
||||
if certificateIssued {
|
||||
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
|
||||
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
|
||||
}
|
||||
@@ -636,7 +656,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
|
||||
internalStatus := protoStatusToInternal(protoStatus)
|
||||
|
||||
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
|
||||
log.WithContext(ctx).WithError(err).Error("failed to update service status")
|
||||
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
|
||||
}
|
||||
@@ -651,22 +671,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
|
||||
}
|
||||
|
||||
// protoStatusToInternal maps proto status to internal status
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
|
||||
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
||||
switch protoStatus {
|
||||
case proto.ProxyStatus_PROXY_STATUS_PENDING:
|
||||
return reverseproxy.StatusPending
|
||||
return rpservice.StatusPending
|
||||
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
|
||||
return reverseproxy.StatusActive
|
||||
return rpservice.StatusActive
|
||||
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
|
||||
return reverseproxy.StatusTunnelNotCreated
|
||||
return rpservice.StatusTunnelNotCreated
|
||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
|
||||
return reverseproxy.StatusCertificatePending
|
||||
return rpservice.StatusCertificatePending
|
||||
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
|
||||
return reverseproxy.StatusCertificateFailed
|
||||
return rpservice.StatusCertificateFailed
|
||||
case proto.ProxyStatus_PROXY_STATUS_ERROR:
|
||||
return reverseproxy.StatusError
|
||||
return rpservice.StatusError
|
||||
default:
|
||||
return reverseproxy.StatusError
|
||||
return rpservice.StatusError
|
||||
}
|
||||
}
|
||||
|
||||
@@ -731,7 +751,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||
}
|
||||
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
|
||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
|
||||
@@ -794,8 +814,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC configuration for token validation
|
||||
// in the format needed by ToProtoMapping.
|
||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
|
||||
return reverseproxy.OIDCValidationConfig{
|
||||
func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return proxy.OIDCValidationConfig{
|
||||
Issuer: s.oidcConfig.Issuer,
|
||||
Audiences: []string{s.oidcConfig.Audience},
|
||||
KeysLocation: s.oidcConfig.KeysLocation,
|
||||
@@ -854,12 +874,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||
// Find the service by domain to get its signing key
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var service *rpservice.Service
|
||||
for _, svc := range services {
|
||||
if svc.Domain == domain {
|
||||
service = svc
|
||||
@@ -925,8 +945,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
|
||||
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
|
||||
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account services: %w", err)
|
||||
}
|
||||
@@ -1047,8 +1067,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
|
||||
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
|
||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get services: %w", err)
|
||||
}
|
||||
@@ -1062,7 +1082,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
|
||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
|
||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
proxiesByAccount map[string][]*service.Service
|
||||
err error
|
||||
}
|
||||
|
||||
@@ -21,31 +21,31 @@ func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
return m.proxiesByAccount[accountID], nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
return []*reverseproxy.Service{}, nil
|
||||
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
return []*service.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
|
||||
@@ -56,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
|
||||
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -68,16 +68,16 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
|
||||
return &reverseproxy.Service{}, nil
|
||||
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
|
||||
return &service.Service{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
|
||||
return &reverseproxy.ExposeServiceResponse{}, nil
|
||||
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return &service.ExposeServiceResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
@@ -111,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name string
|
||||
domain string
|
||||
userID string
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
proxiesByAccount map[string][]*service.Service
|
||||
users map[string]*types.User
|
||||
proxyErr error
|
||||
userErr error
|
||||
@@ -122,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user not found",
|
||||
domain: "app.example.com",
|
||||
userID: "unknown-user",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||
},
|
||||
users: map[string]*types.User{},
|
||||
@@ -133,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "proxy not found in user's account",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||
proxiesByAccount: map[string][]*service.Service{},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
},
|
||||
@@ -144,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "proxy exists in different account - not accessible",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
@@ -157,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "no bearer auth configured - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
|
||||
},
|
||||
users: map[string]*types.User{
|
||||
"user1": {Id: "user1", AccountID: "account1"},
|
||||
@@ -169,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "bearer auth disabled - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{Enabled: false},
|
||||
},
|
||||
}},
|
||||
},
|
||||
@@ -187,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "bearer auth enabled but no groups configured - same account allows access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{},
|
||||
},
|
||||
@@ -208,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user not in allowed groups",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
@@ -230,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user in one of the allowed groups - allow access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
@@ -251,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "user in all allowed groups - allow access",
|
||||
domain: "app.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{
|
||||
Domain: "app.example.com",
|
||||
AccountID: "account1",
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"group1", "group2"},
|
||||
},
|
||||
@@ -284,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
name: "multiple proxies in account - finds correct one",
|
||||
domain: "app2.example.com",
|
||||
userID: "user1",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {
|
||||
{Domain: "app1.example.com", AccountID: "account1"},
|
||||
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
|
||||
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
|
||||
{Domain: "app3.example.com", AccountID: "account1"},
|
||||
},
|
||||
},
|
||||
@@ -301,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &ProxyServiceServer{
|
||||
reverseProxyManager: &mockReverseProxyManager{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: tt.proxiesByAccount,
|
||||
err: tt.proxyErr,
|
||||
},
|
||||
@@ -328,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name string
|
||||
accountID string
|
||||
domain string
|
||||
proxiesByAccount map[string][]*reverseproxy.Service
|
||||
proxiesByAccount map[string][]*service.Service
|
||||
err error
|
||||
expectProxy bool
|
||||
expectErr bool
|
||||
@@ -337,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name: "proxy found",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {
|
||||
{Domain: "other.example.com", AccountID: "account1"},
|
||||
{Domain: "app.example.com", AccountID: "account1"},
|
||||
@@ -350,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name: "proxy not found in account",
|
||||
accountID: "account1",
|
||||
domain: "unknown.example.com",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{
|
||||
proxiesByAccount: map[string][]*service.Service{
|
||||
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
|
||||
},
|
||||
expectProxy: false,
|
||||
@@ -360,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
name: "empty proxy list for account",
|
||||
accountID: "account1",
|
||||
domain: "app.example.com",
|
||||
proxiesByAccount: map[string][]*reverseproxy.Service{},
|
||||
proxiesByAccount: map[string][]*service.Service{},
|
||||
expectProxy: false,
|
||||
expectErr: true,
|
||||
},
|
||||
@@ -378,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := &ProxyServiceServer{
|
||||
reverseProxyManager: &mockReverseProxyManager{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: tt.proxiesByAccount,
|
||||
err: tt.err,
|
||||
},
|
||||
|
||||
@@ -1,19 +1,73 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type testProxyController struct {
|
||||
mu sync.Mutex
|
||||
clusterProxies map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func newTestProxyController() *testProxyController {
|
||||
return &testProxyController{
|
||||
clusterProxies: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return proxy.OIDCValidationConfig{}
|
||||
}
|
||||
|
||||
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if _, ok := c.clusterProxies[clusterAddr]; !ok {
|
||||
c.clusterProxies[clusterAddr] = make(map[string]struct{})
|
||||
}
|
||||
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
|
||||
delete(proxies, proxyID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
proxies, ok := c.clusterProxies[clusterAddr]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(proxies))
|
||||
for id := range proxies {
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||
// and returns the channel where messages will be received.
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
||||
@@ -25,8 +79,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
|
||||
|
||||
return ch
|
||||
}
|
||||
@@ -41,12 +94,13 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
const numProxies = 3
|
||||
@@ -67,11 +121,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{mapping},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||
|
||||
tokens := make([]string, numProxies)
|
||||
for i, ch := range channels {
|
||||
@@ -101,12 +151,13 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||
@@ -119,11 +170,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{mapping},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
|
||||
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
@@ -135,18 +182,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
// Delete operations should not generate tokens
|
||||
assert.Empty(t, resp1.Mapping[0].AuthToken)
|
||||
assert.Empty(t, resp2.Mapping[0].AuthToken)
|
||||
|
||||
// No tokens should have been created
|
||||
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore := NewOneTimeTokenStore(time.Hour)
|
||||
defer tokenStore.Close()
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
@@ -82,7 +82,7 @@ type Server struct {
|
||||
syncLimEnabled bool
|
||||
syncLim int32
|
||||
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
reverseProxyManager rpservice.Manager
|
||||
reverseProxyMu sync.RWMutex
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -34,11 +34,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyManager := &testValidateSessionProxyManager{store: testStore}
|
||||
serviceManager := &testValidateSessionServiceManager{store: testStore}
|
||||
usersManager := &testValidateSessionUsersManager{store: testStore}
|
||||
proxyManager := &testValidateSessionProxyManager{}
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
|
||||
proxyService.SetProxyManager(proxyManager)
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -54,7 +58,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
||||
|
||||
pubKey, privKey := generateSessionKeyPair(t)
|
||||
|
||||
testProxy := &reverseproxy.Service{
|
||||
testProxy := &service.Service{
|
||||
ID: "testProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Test Proxy",
|
||||
@@ -62,15 +66,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
||||
Enabled: true,
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||
|
||||
restrictedProxy := &reverseproxy.Service{
|
||||
restrictedProxy := &service.Service{
|
||||
ID: "restrictedProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Restricted Proxy",
|
||||
@@ -78,8 +82,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
|
||||
Enabled: true,
|
||||
SessionPrivateKey: privKey,
|
||||
SessionPublicKey: pubKey,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"allowedGroupId"},
|
||||
},
|
||||
@@ -239,79 +243,101 @@ func TestValidateSession_MissingToken(t *testing.T) {
|
||||
assert.Contains(t, resp.DeniedReason, "missing")
|
||||
}
|
||||
|
||||
type testValidateSessionProxyManager struct {
|
||||
type testValidateSessionServiceManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) DeleteAllServices(_ context.Context, _, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
|
||||
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) StartExposeReaper(_ context.Context) {}
|
||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
type testValidateSessionProxyManager struct{}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testValidateSessionUsersManager struct {
|
||||
store store.Store
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -85,7 +85,7 @@ type DefaultAccountManager struct {
|
||||
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
serviceManager service.Manager
|
||||
|
||||
// config contains the management server configuration
|
||||
config *nbconfig.Config
|
||||
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
am.reverseProxyManager = serviceManager
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
am.serviceManager = serviceManager
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -395,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
}
|
||||
if reloadReverseProxy {
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
@@ -730,7 +730,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
|
||||
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -142,5 +142,5 @@ type Manager interface {
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
SetServiceManager(serviceManager reverseproxy.Manager)
|
||||
SetServiceManager(serviceManager service.Manager)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
activity "github.com/netbirdio/netbird/management/server/activity"
|
||||
idp "github.com/netbirdio/netbird/management/server/idp"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -1494,7 +1494,7 @@ func (mr *MockManagerMockRecorder) SaveUser(ctx, accountID, initiatorUserID, upd
|
||||
}
|
||||
|
||||
// SetServiceManager mocks base method.
|
||||
func (m *MockManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
func (m *MockManager) SetServiceManager(serviceManager service.Manager) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetServiceManager", serviceManager)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -27,8 +28,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -1803,12 +1806,12 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
Services: []*reverseproxy.Service{
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "service1",
|
||||
Name: "test-service",
|
||||
AccountID: "account1",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
Targets: []*service.Target{},
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
@@ -3113,6 +3116,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
proxyManager := proxy.NewMockManager(ctrl)
|
||||
proxyManager.EXPECT().
|
||||
CleanupStale(gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
@@ -3123,8 +3132,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil)
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, settingsMockManager, proxyGrpcServer, nil))
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
@@ -766,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -73,7 +73,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||
}
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
|
||||
oidcServer := newFakeOIDCServer()
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
@@ -208,9 +209,10 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetProxyManager(&testServiceManager{store: testStore})
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
|
||||
handler := NewAuthCallbackHandler(proxyService, nil)
|
||||
|
||||
@@ -239,12 +241,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
testProxy := &reverseproxy.Service{
|
||||
testProxy := &service.Service{
|
||||
ID: "testProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Test Proxy",
|
||||
Domain: "test-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
@@ -254,8 +256,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"allowedGroupId"},
|
||||
},
|
||||
@@ -265,12 +267,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||
|
||||
restrictedProxy := &reverseproxy.Service{
|
||||
restrictedProxy := &service.Service{
|
||||
ID: "restrictedProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Restricted Proxy",
|
||||
Domain: "restricted-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
@@ -280,8 +282,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"restrictedGroupId"},
|
||||
},
|
||||
@@ -291,12 +293,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||
|
||||
noAuthProxy := &reverseproxy.Service{
|
||||
noAuthProxy := &service.Service{
|
||||
ID: "noAuthProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "No Auth Proxy",
|
||||
Domain: "no-auth-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
@@ -306,8 +308,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
@@ -361,19 +363,19 @@ func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -385,7 +387,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -397,15 +399,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -413,7 +415,7 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
|
||||
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,10 +9,13 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -91,12 +94,24 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
}
|
||||
|
||||
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
|
||||
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
|
||||
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
|
||||
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, settingsManager, proxyServiceServer, domainManager)
|
||||
proxyServiceServer.SetProxyManager(reverseProxyManager)
|
||||
am.SetServiceManager(reverseProxyManager)
|
||||
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy token store: %v", err)
|
||||
}
|
||||
noopMeter := noop.NewMeterProvider().Meter("")
|
||||
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy controller: %v", err)
|
||||
}
|
||||
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
|
||||
proxyServiceServer.SetServiceManager(serviceManager)
|
||||
am.SetServiceManager(serviceManager)
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
@@ -114,7 +129,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -358,12 +358,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
}
|
||||
servicesTargets += len(service.Targets)
|
||||
|
||||
switch reverseproxy.ProxyStatus(service.Meta.Status) {
|
||||
case reverseproxy.StatusActive:
|
||||
switch rpservice.Status(service.Meta.Status) {
|
||||
case rpservice.StatusActive:
|
||||
servicesStatusActive++
|
||||
case reverseproxy.StatusPending:
|
||||
case rpservice.StatusPending:
|
||||
servicesStatusPending++
|
||||
case reverseproxy.StatusError, reverseproxy.StatusCertificateFailed, reverseproxy.StatusTunnelNotCreated:
|
||||
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
|
||||
servicesStatusError++
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -116,29 +116,29 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
},
|
||||
},
|
||||
Services: []*reverseproxy.Service{
|
||||
Services: []*rpservice.Service{
|
||||
{
|
||||
ID: "svc1",
|
||||
Enabled: true,
|
||||
Targets: []*reverseproxy.Target{
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "peer"},
|
||||
{TargetType: "host"},
|
||||
},
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{Enabled: true},
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||
},
|
||||
Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusActive)},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||
},
|
||||
{
|
||||
ID: "svc2",
|
||||
Enabled: false,
|
||||
Targets: []*reverseproxy.Target{
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "domain"},
|
||||
},
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: true},
|
||||
Auth: rpservice.AuthConfig{
|
||||
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
|
||||
},
|
||||
Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusPending)},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -148,7 +148,7 @@ type MockAccountManager struct {
|
||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -37,19 +37,19 @@ type managerImpl struct {
|
||||
permissionsManager permissions.Manager
|
||||
groupsManager groups.Manager
|
||||
accountManager account.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
serviceManager service.Manager
|
||||
}
|
||||
|
||||
type mockManager struct {
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
reverseProxyManager: reverseproxyManager,
|
||||
serviceManager: reverseproxyManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
|
||||
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
|
||||
go func() {
|
||||
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
|
||||
}
|
||||
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||
require.NoError(t, err)
|
||||
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||
require.Error(t, err)
|
||||
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||
require.Error(t, err)
|
||||
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.NoError(t, err)
|
||||
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.Error(t, err)
|
||||
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.NoError(t, err)
|
||||
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.NoError(t, err)
|
||||
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.NoError(t, err)
|
||||
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
|
||||
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
}
|
||||
|
||||
@@ -355,6 +355,7 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest)
|
||||
Hostname: a.Name,
|
||||
GoOS: "js",
|
||||
OS: "js",
|
||||
KernelVersion: "wasm",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,9 +28,10 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -2075,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
@@ -2090,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
|
||||
var s reverseproxy.Service
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
@@ -2121,7 +2122,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
}
|
||||
|
||||
s.Meta = reverseproxy.ServiceMeta{}
|
||||
s.Meta = rpservice.Meta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
}
|
||||
@@ -2142,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
s.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
s.Targets = []*reverseproxy.Target{}
|
||||
s.Targets = []*rpservice.Target{}
|
||||
return &s, nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -2154,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
|
||||
serviceIDs := make([]string, len(services))
|
||||
serviceMap := make(map[string]*reverseproxy.Service)
|
||||
serviceMap := make(map[string]*rpservice.Service)
|
||||
for i, s := range services {
|
||||
serviceIDs[i] = s.ID
|
||||
serviceMap[s.ID] = s
|
||||
@@ -2165,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
|
||||
var t reverseproxy.Target
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
|
||||
var t rpservice.Target
|
||||
var path sql.NullString
|
||||
err := row.Scan(
|
||||
&t.ID,
|
||||
@@ -4852,7 +4853,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
@@ -4866,16 +4867,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
}
|
||||
|
||||
// Create target type instance outside transaction to avoid variable shadowing
|
||||
targetType := &rpservice.Target{}
|
||||
|
||||
// Use a transaction to ensure atomic updates of the service and its targets
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete existing targets
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4896,7 +4900,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete service from store")
|
||||
@@ -4910,7 +4914,7 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete target from store")
|
||||
@@ -4924,7 +4928,7 @@ func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete targets from store")
|
||||
@@ -4934,8 +4938,8 @@ func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, s
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID retrieves all targets for a given service
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
var targets []*reverseproxy.Target
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
|
||||
var targets []*rpservice.Target
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
@@ -4949,13 +4953,13 @@ func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength Locki
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var service *rpservice.Service
|
||||
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -4973,30 +4977,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
||||
}
|
||||
|
||||
for _, service := range serviceList {
|
||||
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
var service *reverseproxy.Service
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
var service *rpservice.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -5014,13 +4996,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -5036,13 +5018,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -5270,13 +5252,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
|
||||
return query
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var target *reverseproxy.Target
|
||||
var target *rpservice.Target
|
||||
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -5289,3 +5271,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// SaveProxy saves or updates a proxy in the database
|
||||
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Update("last_seen", time.Now())
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("last_seen < ?", cutoffTime).
|
||||
Delete(&proxy.Proxy{})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
|
||||
}
|
||||
|
||||
if result.RowsAffected > 0 {
|
||||
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
||||
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
|
||||
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
|
||||
}
|
||||
|
||||
for i := len(models) - 1; i >= 0; i-- {
|
||||
|
||||
@@ -25,9 +25,10 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -252,14 +253,13 @@ type Store interface {
|
||||
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
||||
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
||||
|
||||
CreateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
UpdateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
CreateService(ctx context.Context, service *rpservice.Service) error
|
||||
UpdateService(ctx context.Context, service *rpservice.Service) error
|
||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
|
||||
GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
||||
|
||||
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||
@@ -271,12 +271,16 @@ type Store interface {
|
||||
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
|
||||
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
|
||||
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,9 +12,10 @@ import (
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
zones "github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
types "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// CleanupStaleProxies mocks base method.
|
||||
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
|
||||
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -293,7 +308,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
|
||||
}
|
||||
|
||||
// CreateService mocks base method.
|
||||
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateService", ctx, service)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -1123,10 +1138,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
|
||||
}
|
||||
|
||||
// GetAccountServices mocks base method.
|
||||
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1227,6 +1242,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1857,10 +1887,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
|
||||
ret0, _ := ret[0].(*reverseproxy.Service)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1872,10 +1902,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].(*reverseproxy.Service)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1887,10 +1917,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
|
||||
}
|
||||
|
||||
// GetServiceTargetByTargetID mocks base method.
|
||||
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) {
|
||||
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
|
||||
ret0, _ := ret[0].(*reverseproxy.Target)
|
||||
ret0, _ := ret[0].(*service.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1902,10 +1932,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
|
||||
}
|
||||
|
||||
// GetServices mocks base method.
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1916,21 +1946,6 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID mocks base method.
|
||||
func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1991,10 +2006,10 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID mocks base method.
|
||||
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Target)
|
||||
ret0, _ := ret[0].([]*service.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -2610,6 +2625,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
|
||||
}
|
||||
|
||||
// SaveProxy mocks base method.
|
||||
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveProxy indicates an expected call of SaveProxy.
|
||||
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2805,8 +2834,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateService", ctx, service)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -100,7 +100,7 @@ type Account struct {
|
||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
// Settings is a dictionary of Account settings
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
|
||||
networkResources = append(networkResources, resource.Copy())
|
||||
}
|
||||
|
||||
services := []*reverseproxy.Service{}
|
||||
services := []*service.Service{}
|
||||
for _, service := range a.Services {
|
||||
services = append(services, service.Copy())
|
||||
}
|
||||
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
|
||||
port, ok := a.resolveTargetPort(ctx, target)
|
||||
if !ok {
|
||||
return
|
||||
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
|
||||
if target.Port != 0 {
|
||||
return target.Port, true
|
||||
}
|
||||
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
|
||||
@@ -42,6 +42,8 @@ var (
|
||||
acmeCerts bool
|
||||
acmeAddr string
|
||||
acmeDir string
|
||||
acmeEABKID string
|
||||
acmeEABHMACKey string
|
||||
acmeChallengeType string
|
||||
debugEndpoint bool
|
||||
debugEndpointAddr string
|
||||
@@ -74,6 +76,8 @@ func init() {
|
||||
rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates automatically")
|
||||
rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges (only used when acme-challenge-type is http-01)")
|
||||
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
|
||||
rootCmd.Flags().StringVar(&acmeEABKID, "acme-eab-kid", envStringOrDefault("NB_PROXY_ACME_EAB_KID", ""), "ACME EAB KID for account registration")
|
||||
rootCmd.Flags().StringVar(&acmeEABHMACKey, "acme-eab-hmac-key", envStringOrDefault("NB_PROXY_ACME_EAB_HMAC_KEY", ""), "ACME EAB HMAC key for account registration")
|
||||
rootCmd.Flags().StringVar(&acmeChallengeType, "acme-challenge-type", envStringOrDefault("NB_PROXY_ACME_CHALLENGE_TYPE", "tls-alpn-01"), "ACME challenge type: tls-alpn-01 (default, port 443 only) or http-01 (requires port 80)")
|
||||
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
|
||||
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
|
||||
@@ -149,6 +153,8 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
GenerateACMECertificates: acmeCerts,
|
||||
ACMEChallengeAddress: acmeAddr,
|
||||
ACMEDirectory: acmeDir,
|
||||
ACMEEABKID: acmeEABKID,
|
||||
ACMEEABHMACKey: acmeEABHMACKey,
|
||||
ACMEChallengeType: acmeChallengeType,
|
||||
DebugEndpointEnabled: debugEndpoint,
|
||||
DebugEndpointAddress: debugEndpointAddr,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -59,7 +60,10 @@ type Manager struct {
|
||||
// NewManager creates a new ACME certificate manager. The certDir is used
|
||||
// for caching certificates. The lockMethod controls cross-replica
|
||||
// coordination strategy (see CertLockMethod constants).
|
||||
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
|
||||
// eabKID and eabHMACKey are optional External Account Binding credentials
|
||||
// required for some CAs like ZeroSSL. The eabHMACKey should be the base64
|
||||
// URL-encoded string provided by the CA.
|
||||
func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
@@ -70,10 +74,26 @@ func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *l
|
||||
certNotifier: notifier,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
var eab *acme.ExternalAccountBinding
|
||||
if eabKID != "" && eabHMACKey != "" {
|
||||
decodedKey, err := base64.RawURLEncoding.DecodeString(eabHMACKey)
|
||||
if err != nil {
|
||||
logger.Errorf("failed to decode EAB HMAC key: %v", err)
|
||||
} else {
|
||||
eab = &acme.ExternalAccountBinding{
|
||||
KID: eabKID,
|
||||
Key: decodedKey,
|
||||
}
|
||||
logger.Infof("configured External Account Binding with KID: %s", eabKID)
|
||||
}
|
||||
}
|
||||
|
||||
mgr.Manager = &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: mgr.hostPolicy,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
ExternalAccountBinding: eab,
|
||||
Client: &acme.Client{
|
||||
DirectoryURL: acmeURL,
|
||||
},
|
||||
@@ -136,7 +156,7 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) {
|
||||
cert, err := mgr.GetCertificate(hello)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
mgr.logger.Warnf("prefetch certificate for domain %q: %v", name, err)
|
||||
mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), err)
|
||||
mgr.setDomainState(d, domainFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestHostPolicy(t *testing.T) {
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
|
||||
mgr.AddDomain("example.com", "acc1", "rp1")
|
||||
|
||||
// Wait for the background prefetch goroutine to finish so the temp dir
|
||||
@@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDomainStates(t *testing.T) {
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "")
|
||||
|
||||
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
|
||||
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")
|
||||
|
||||
@@ -18,8 +18,9 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -37,7 +38,7 @@ type integrationTestSetup struct {
|
||||
grpcServer *grpc.Server
|
||||
grpcAddr string
|
||||
cleanup func()
|
||||
services []*reverseproxy.Service
|
||||
services []*service.Service
|
||||
}
|
||||
|
||||
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
@@ -66,13 +67,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
// Create test services in the store
|
||||
services := []*reverseproxy.Service{
|
||||
services := []*service.Service{
|
||||
{
|
||||
ID: "rp-1",
|
||||
AccountID: "test-account-1",
|
||||
Name: "Test App 1",
|
||||
Domain: "app1.test.proxy.io",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
@@ -91,7 +92,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
AccountID: "test-account-1",
|
||||
Name: "Test App 2",
|
||||
Domain: "app2.test.proxy.io",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "10.0.0.2",
|
||||
Port: 8080,
|
||||
@@ -112,7 +113,8 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
}
|
||||
|
||||
// Create real token store
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create real users manager
|
||||
usersManager := users.NewManager(testStore)
|
||||
@@ -124,17 +126,23 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
}
|
||||
|
||||
proxyManager := &testProxyManager{}
|
||||
|
||||
proxyService := nbgrpc.NewProxyServiceServer(
|
||||
&testAccessLogManager{},
|
||||
tokenStore,
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
proxyManager,
|
||||
)
|
||||
|
||||
// Use store-backed service manager
|
||||
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
||||
proxyService.SetProxyManager(svcMgr)
|
||||
proxyService.SetServiceManager(svcMgr)
|
||||
|
||||
proxyController := &testProxyController{}
|
||||
proxyService.SetProxyController(proxyController)
|
||||
|
||||
// Start real gRPC server
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
@@ -185,6 +193,52 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
||||
return nil, 0, nil
|
||||
}
|
||||
|
||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||
type testProxyManager struct{}
|
||||
|
||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
|
||||
type testProxyController struct{}
|
||||
|
||||
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetOIDCValidationConfig() nbproxy.OIDCValidationConfig {
|
||||
return nbproxy.OIDCValidationConfig{}
|
||||
}
|
||||
|
||||
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetProxiesForCluster(_ string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// storeBackedServiceManager reads directly from the real store.
|
||||
type storeBackedServiceManager struct {
|
||||
store store.Store
|
||||
@@ -195,19 +249,19 @@ func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accou
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -219,7 +273,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -231,15 +285,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -247,8 +301,8 @@ func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context,
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
|
||||
return &reverseproxy.ExposeServiceResponse{}, nil
|
||||
func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return &service.ExposeServiceResponse{}, nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
|
||||
@@ -84,6 +84,10 @@ type Server struct {
|
||||
GenerateACMECertificates bool
|
||||
ACMEChallengeAddress string
|
||||
ACMEDirectory string
|
||||
// ACMEEABKID is the External Account Binding Key ID for CAs that require EAB (e.g., ZeroSSL).
|
||||
ACMEEABKID string
|
||||
// ACMEEABHMACKey is the External Account Binding HMAC key (base64 URL-encoded) for CAs that require EAB.
|
||||
ACMEEABHMACKey string
|
||||
// ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01".
|
||||
// Defaults to "tls-alpn-01" if not specified.
|
||||
ACMEChallengeType string
|
||||
@@ -419,7 +423,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
||||
"acme_server": s.ACMEDirectory,
|
||||
"challenge_type": s.ACMEChallengeType,
|
||||
}).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod)
|
||||
|
||||
if s.ACMEChallengeType == "http-01" {
|
||||
s.http = &http.Server{
|
||||
|
||||
Reference in New Issue
Block a user