diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 901cdf0e3..c6c41bfe5 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -31,19 +31,15 @@ type store interface { type proxyManager interface { GetActiveClusterAddresses(ctx context.Context) ([]string, error) -} - -type clusterCapabilities interface { - ClusterSupportsCustomPorts(clusterAddr string) *bool - ClusterRequireSubdomain(clusterAddr string) *bool + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool } type Manager struct { - store store - validator domain.Validator - proxyManager proxyManager - clusterCapabilities clusterCapabilities - permissionsManager permissions.Manager + store store + validator domain.Validator + proxyManager proxyManager + permissionsManager permissions.Manager accountManager account.Manager } @@ -57,11 +53,6 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio } } -// SetClusterCapabilities sets the cluster capabilities provider for domain queries. -func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) { - m.clusterCapabilities = caps -} - func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { @@ -97,10 +88,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d Type: domain.TypeFree, Validated: true, } - if m.clusterCapabilities != nil { - d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster) - d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster) - } + d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster) + d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster) ret = append(ret, d) } @@ -114,8 +103,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d Type: domain.TypeCustom, Validated: d.Validated, } - if m.clusterCapabilities != nil && d.TargetCluster != "" { - cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster) + if d.TargetCluster != "" { + cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster) } // Custom domains never require a subdomain by default since // the account owns them and should be able to use the bare domain. diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 9b0de53b4..0368b84de 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,11 +11,13 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error Disconnect(ctx context.Context, proxyID string) error Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusters(ctx context.Context) ([]Cluster, error) + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool CleanupStale(ctx context.Context, inactivityDuration time.Duration) error } @@ -34,6 +36,4 @@ type Controller interface { RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error GetProxiesForCluster(clusterAddr string) []string - ClusterSupportsCustomPorts(clusterAddr string) *bool - ClusterRequireSubdomain(clusterAddr string) *bool } diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go index 05a0c9048..e5b3e9886 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/controller.go +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -72,17 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster return nil } -// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports. -func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool { - return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr) -} - -// ClusterRequireSubdomain returns whether the cluster requires a subdomain label. -// Returns nil when no proxy has reported the capability (defaults to false). -func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool { - return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr) -} - // GetProxiesForCluster returns all proxy IDs registered for a specific cluster. func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { proxySet, ok := c.clusterProxies.Load(clusterAddr) diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index dac6d3ce3..a92fffab9 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -16,6 +16,8 @@ type store interface { UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error } @@ -38,9 +40,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { }, nil } -// Connect registers a new proxy connection in the database -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// Connect registers a new proxy connection in the database. +// capabilities may be nil for old proxies that do not report them. +func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { now := time.Now() + var caps proxy.Capabilities + if capabilities != nil { + caps = *capabilities + } p := &proxy.Proxy{ ID: proxyID, ClusterAddress: clusterAddress, @@ -48,6 +55,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress LastSeen: now, ConnectedAt: &now, Status: "connected", + Capabilities: caps, } if err := m.store.SaveProxy(ctx, p); err != nil { @@ -118,6 +126,18 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) return clusters, nil } +// ClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr) +} + +// ClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy has reported capabilities. +func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return m.store.GetClusterRequireSubdomain(ctx, clusterAddr) +} + // CleanupStale removes proxies that haven't sent heartbeat in the specified duration func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index da3df12a2..97466c503 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -50,18 +50,46 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration) } -// Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// ClusterSupportsCustomPorts mocks base method. +func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. +func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr) +} + +// ClusterRequireSubdomain mocks base method. +func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain. +func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr) +} + +// Connect mocks base method. +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) ret0, _ := ret[0].(error) return ret0 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) } // Disconnect mocks base method. @@ -145,34 +173,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder { return m.recorder } -// ClusterSupportsCustomPorts mocks base method. -func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts. -func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) -} - -// ClusterRequireSubdomain mocks base method. -func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr) - ret0, _ := ret[0].(*bool) - return ret0 -} - -// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain. -func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr) -} - // GetOIDCValidationConfig mocks base method. func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 671eb109f..4102e50fe 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -2,6 +2,17 @@ package proxy import "time" +// Capabilities describes what a proxy can handle, as reported via gRPC. +// Nil fields mean the proxy never reported this capability. +type Capabilities struct { + // SupportsCustomPorts indicates whether this proxy can bind arbitrary + // ports for TCP/UDP services. TLS uses SNI routing and is not gated. + SupportsCustomPorts *bool + // RequireSubdomain indicates whether a subdomain label is required in + // front of the cluster domain. + RequireSubdomain *bool +} + // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` @@ -11,6 +22,7 @@ type Proxy struct { ConnectedAt *time.Time DisconnectedAt *time.Time Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time } diff --git a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go index 8b652c7e1..4a7647d90 100644 --- a/management/internals/modules/reverseproxy/service/manager/l4_port_test.go +++ b/management/internals/modules/reverseproxy/service/manager/l4_port_test.go @@ -75,11 +75,13 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor require.NoError(t, err) mockCtrl := proxy.NewMockController(ctrl) - mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes() - mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes() mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes() + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes() + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes() + accountMgr := &mock_server.MockAccountManager{ StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, @@ -93,6 +95,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor accountManager: accountMgr, permissionsManager: permissions.NewManager(testStore), proxyController: mockCtrl, + capabilities: mockCaps, clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}}, } mgr.exposeReaper = &exposeReaper{manager: mgr} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index ea4fa9d1e..db393ef38 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -75,22 +75,30 @@ type ClusterDeriver interface { GetClusterDomains() []string } +// CapabilityProvider queries proxy cluster capabilities from the database. +type CapabilityProvider interface { + ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool +} + type Manager struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager proxyController proxy.Controller + capabilities CapabilityProvider clusterDeriver ClusterDeriver exposeReaper *exposeReaper } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager { +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager { mgr := &Manager{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, proxyController: proxyController, + capabilities: capabilities, clusterDeriver: clusterDeriver, } mgr.exposeReaper = &exposeReaper{manager: mgr} @@ -237,7 +245,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri } service.ProxyCluster = proxyCluster - if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil { + if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil { return err } } @@ -268,11 +276,11 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri // validateSubdomainRequirement checks whether the domain can be used bare // (without a subdomain label) on the given cluster. If the cluster reports // require_subdomain=true and the domain equals the cluster domain, it rejects. -func (m *Manager) validateSubdomainRequirement(domain, cluster string) error { +func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error { if domain != cluster { return nil } - requireSub := m.proxyController.ClusterRequireSubdomain(cluster) + requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster) if requireSub != nil && *requireSub { return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain) } @@ -312,7 +320,7 @@ func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service if !service.IsL4Protocol(svc.Mode) { return nil } - customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster) + customPorts := m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster) if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) { if svc.Source != service.SourceEphemeral { return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster) @@ -520,12 +528,12 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St } if existingService.Terminated { - return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated") - } + return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated") + } - if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { - return err - } + if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil { + return err + } updateInfo.oldCluster = existingService.ProxyCluster updateInfo.domainChanged = existingService.Domain != service.Domain @@ -538,7 +546,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St service.ProxyCluster = existingService.ProxyCluster } - if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil { + if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil { return err } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 18e1be26e..f6e532118 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -1324,11 +1324,11 @@ func TestValidateSubdomainRequirement(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctrl := gomock.NewController(t) - mockCtrl := proxy.NewMockController(ctrl) - mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes() + mockCaps := proxy.NewMockManager(ctrl) + mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes() - mgr := &Manager{proxyController: mockCtrl} - err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster) + mgr := &Manager{capabilities: mockCaps} + err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster) if tc.wantErr { require.Error(t, err) assert.Contains(t, err.Error(), "requires a subdomain label") diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index a32cf6046..6064bd5b6 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -195,7 +195,7 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ServiceManager() service.Manager { return Create(s, func() service.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager()) }) } @@ -212,9 +212,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager()) - s.AfterInit(func(s *BaseServer) { - m.SetClusterCapabilities(s.ServiceProxyController()) - }) return &m }) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 5fa382af0..07732cea6 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -182,9 +182,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) } - // Register proxy in database - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil { - log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) + // Register proxy in database with capabilities + var caps *proxy.Capabilities + if c := req.GetCapabilities(); c != nil { + caps = &proxy.Capabilities{ + SupportsCustomPorts: c.SupportsCustomPorts, + RequireSubdomain: c.RequireSubdomain, + } + } + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) + s.connectedProxies.Delete(proxyID) + if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { + log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) + } + return status.Errorf(codes.Internal, "register proxy in database: %v", err) } log.WithFields(log.Fields{ @@ -297,6 +309,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn * } m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig()) + if !proxyAcceptsMapping(conn, m) { + continue + } mappings = append(mappings, m) } return mappings, nil @@ -445,22 +460,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd log.Debugf("Sending service update to cluster %s", clusterAddr) for _, proxyID := range proxyIDs { - if connVal, ok := s.connectedProxies.Load(proxyID); ok { - conn := connVal.(*proxyConnection) - msg := s.perProxyMessage(updateResponse, proxyID) - if msg == nil { - continue - } - select { - case conn.sendChan <- msg: - log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) - default: - log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) - } + connVal, ok := s.connectedProxies.Load(proxyID) + if !ok { + continue + } + conn := connVal.(*proxyConnection) + if !proxyAcceptsMapping(conn, update) { + log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) + continue + } + msg := s.perProxyMessage(updateResponse, proxyID) + if msg == nil { + continue + } + select { + case conn.sendChan <- msg: + log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + default: + log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) } } } +// proxyAcceptsMapping returns whether the proxy should receive this mapping. +// Old proxies that never reported capabilities are skipped for non-TLS L4 +// mappings with a custom listen port, since they don't understand the +// protocol. Proxies that report capabilities (even SupportsCustomPorts=false) +// are new enough to handle the mapping. TLS uses SNI routing and works on +// any proxy. Delete operations are always sent so proxies can clean up. +func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { + return true + } + if mapping.ListenPort == 0 || mapping.Mode == "tls" { + return true + } + // Old proxies that never reported capabilities don't understand + // custom port mappings. + return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil +} + // perProxyMessage returns a copy of update with a fresh one-time token for // create/update operations. For delete operations the original mapping is // used unchanged because proxies do not need to authenticate for removal. @@ -508,64 +547,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } } -// ClusterSupportsCustomPorts returns whether any connected proxy in the given -// cluster reports custom port support. Returns nil if no proxy has reported -// capabilities (old proxies that predate the field). -func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool { - if s.proxyController == nil { - return nil - } - - var hasCapabilities bool - for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { - connVal, ok := s.connectedProxies.Load(pid) - if !ok { - continue - } - conn := connVal.(*proxyConnection) - if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil { - continue - } - if *conn.capabilities.SupportsCustomPorts { - return ptr(true) - } - hasCapabilities = true - } - if hasCapabilities { - return ptr(false) - } - return nil -} - -// ClusterRequireSubdomain returns whether any connected proxy in the given -// cluster reports that a subdomain is required. Returns nil if no proxy has -// reported the capability (defaults to not required). -func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool { - if s.proxyController == nil { - return nil - } - - var hasCapabilities bool - for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) { - connVal, ok := s.connectedProxies.Load(pid) - if !ok { - continue - } - conn := connVal.(*proxyConnection) - if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil { - continue - } - if *conn.capabilities.RequireSubdomain { - return ptr(true) - } - hasCapabilities = true - } - if hasCapabilities { - return ptr(false) - } - return nil -} - func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 83c99020d..d5aed3dee 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -53,14 +53,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus return nil } -func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { - return ptr(true) -} - -func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool { - return nil -} - func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { c.mu.Lock() defer c.mu.Unlock() @@ -355,14 +347,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { const cluster = "proxy.example.com" - // Proxy A supports custom ports. - chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) - // Proxy B does NOT support custom ports (shared cloud proxy). - chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + // Modern proxy reports capabilities. + chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) + // Legacy proxy never reported capabilities (nil). + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) ctx := context.Background() - // TLS passthrough works on all proxies regardless of custom port support. + // TLS passthrough with custom port: all proxies receive it (SNI routing). tlsMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-tls", @@ -375,12 +367,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster) - msgA := drainMapping(chA) - msgB := drainMapping(chB) - assert.NotNil(t, msgA, "proxy-a should receive TLS mapping") - assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)") + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)") - // Send an HTTP mapping: both should receive it. + // TCP mapping with custom port: only modern proxy receives it. + tcpMapping := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-tcp", + AccountId: "account-1", + Domain: "db.example.com", + Mode: "tcp", + ListenPort: 5432, + Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}}, + } + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping") + assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping") + + // HTTP mapping (no listen port): both receive it. httpMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-http", @@ -391,10 +397,16 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { s.SendServiceUpdateToCluster(ctx, httpMapping, cluster) - msgA = drainMapping(chA) - msgB = drainMapping(chB) - assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping") - assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping") + assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping") + assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping") + + // Proxy that reports SupportsCustomPorts=false still receives custom-port + // mappings because it understands the protocol (it's new enough). + chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + + s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster) + + assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping") } func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { @@ -408,7 +420,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { const cluster = "proxy.example.com" - chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + // Legacy proxy (no capabilities) still receives TLS since it uses SNI. + chLegacy := registerFakeProxy(s, "proxy-legacy", cluster) tlsMapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, @@ -421,8 +434,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster) - msg := drainMapping(chShared) - assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support") + msg := drainMapping(chLegacy) + assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)") } // TestServiceModifyNotifications exercises every possible modification @@ -589,7 +602,7 @@ func TestServiceModifyNotifications(t *testing.T) { s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)}) - chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) // TLS passthrough works on all proxies regardless of custom port support s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster) @@ -608,7 +621,7 @@ func TestServiceModifyNotifications(t *testing.T) { } s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" - chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)}) + chLegacy := registerFakeProxy(s, "legacy", cluster) mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED) mapping.ListenPort = 0 // default port diff --git a/management/server/account_test.go b/management/server/account_test.go index fdec43617..548cf31d4 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3138,7 +3138,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU if err != nil { return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil)) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil)) return manager, updateManager, nil } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 55095bbb7..c6e57b1be 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -114,8 +114,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - domainManager.SetClusterCapabilities(serviceProxyController) - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) @@ -244,8 +243,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin if err != nil { t.Fatalf("Failed to create proxy controller: %v", err) } - domainManager.SetClusterCapabilities(serviceProxyController) - serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, proxyMgr, domainManager) proxyServiceServer.SetServiceManager(serviceManager) am.SetServiceManager(serviceManager) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index cf030f51e..ee1947b18 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5445,7 +5445,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string result := s.db.WithContext(ctx). Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5463,7 +5463,7 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, result := s.db.Model(&proxy.Proxy{}). Select("cluster_address as address, COUNT(*) as connected_proxies"). - Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Group("cluster_address"). Scan(&clusters) @@ -5475,6 +5475,63 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, return clusters, nil } +// proxyActiveThreshold is the maximum age of a heartbeat for a proxy to be +// considered active. Must be at least 2x the heartbeat interval (1 min). +const proxyActiveThreshold = 2 * time.Minute + +var validCapabilityColumns = map[string]struct{}{ + "supports_custom_ports": {}, + "require_subdomain": {}, +} + +// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster +// supports custom ports. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports") +} + +// GetClusterRequireSubdomain returns whether any active proxy in the cluster +// requires a subdomain. Returns nil when no proxy reported the capability. +func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + return s.getClusterCapability(ctx, clusterAddr, "require_subdomain") +} + +// getClusterCapability returns an aggregated boolean capability for the given +// cluster. It checks active (connected, recently seen) proxies and returns: +// - *true if any proxy in the cluster has the capability set to true, +// - *false if at least one proxy reported but none set it to true, +// - nil if no proxy reported the capability at all. +func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool { + if _, ok := validCapabilityColumns[column]; !ok { + log.WithContext(ctx).Errorf("invalid capability column: %s", column) + return nil + } + + var result struct { + HasCapability bool + AnyTrue bool + } + + err := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+ + "COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true"). + Where("cluster_address = ? AND status = ? AND last_seen > ?", + clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)). + Scan(&result).Error + + if err != nil { + log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err) + return nil + } + + if !result.HasCapability { + return nil + } + + return &result.AnyTrue +} + // CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { cutoffTime := time.Now().Add(-inactivityDuration) diff --git a/management/server/store/store.go b/management/server/store/store.go index d00dcde38..e24a1efef 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -287,6 +287,8 @@ type Store interface { UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool + GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 235405861..a8648aed7 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -165,6 +165,34 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) } +// GetClusterSupportsCustomPorts mocks base method. +func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts. +func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr) +} + +// GetClusterRequireSubdomain mocks base method. +func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr) + ret0, _ := ret[0].(*bool) + return ret0 +} + +// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain. +func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) +} + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index c30234b5a..796cad622 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -200,7 +200,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { return nil } @@ -220,6 +220,14 @@ func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Clust return nil, nil } +func (m *testProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } @@ -247,14 +255,6 @@ func (c *testProxyController) GetProxiesForCluster(_ string) []string { return nil } -func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool { - return nil -} - -func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool { - return nil -} - // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store