From 8e7b016be2dc14fe1e64a8f7abc25ca7c848bbd9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 4 Mar 2026 18:15:13 +0100 Subject: [PATCH] [management] Replace in-memory expose tracker with SQL-backed operations (#5494) The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart. Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column: - Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists - Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass - Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission - Add composite index on (source, source_peer) for efficient queries - Batch-limit and column-select the reaper query to avoid DB/GC spikes - Filter out malformed rows with empty source_peer --- .../service/manager/expose_tracker.go | 154 ++------ .../service/manager/expose_tracker_test.go | 338 ++++++++---------- .../reverseproxy/service/manager/manager.go | 150 +++++--- .../service/manager/manager_test.go | 34 +- .../modules/reverseproxy/service/service.go | 4 +- management/server/store/sql_store.go | 93 +++++ management/server/store/store.go | 5 + management/server/store/store_mock.go | 59 +++ 8 files changed, 461 insertions(+), 376 deletions(-) diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker.go index 11e1f0110..911add3bb 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker.go @@ -2,7 +2,7 @@ package manager import ( "context" - "sync" + "math/rand/v2" "time" "github.com/netbirdio/netbird/shared/management/status" @@ -13,108 +13,20 @@ const ( exposeTTL = 90 * time.Second exposeReapInterval = 30 * time.Second maxExposesPerPeer = 10 + exposeReapBatch = 100 ) -type trackedExpose struct { - mu sync.Mutex - domain string - accountID string - peerID string - lastRenewed time.Time - expiring bool +type exposeReaper struct { + manager *Manager } -type exposeTracker struct { - activeExposes sync.Map - exposeCreateMu sync.Mutex - manager *Manager -} - -func exposeKey(peerID, domain string) string { - return peerID + ":" + domain -} - -// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new -// active expose session under the same lock. Returns (true, false) if the expose -// was already tracked (duplicate), (false, true) if tracking succeeded, and -// (false, false) if the peer has reached the limit. -func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) { - t.exposeCreateMu.Lock() - defer t.exposeCreateMu.Unlock() - - key := exposeKey(peerID, domain) - _, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{ - domain: domain, - accountID: accountID, - peerID: peerID, - lastRenewed: time.Now(), - }) - if loaded { - return true, false - } - - if t.CountPeerExposes(peerID) > maxExposesPerPeer { - t.activeExposes.Delete(key) - return false, false - } - - return false, true -} - -// UntrackExpose removes an active expose session from tracking. -func (t *exposeTracker) UntrackExpose(peerID, domain string) { - t.activeExposes.Delete(exposeKey(peerID, domain)) -} - -// CountPeerExposes returns the number of active expose sessions for a peer. -func (t *exposeTracker) CountPeerExposes(peerID string) int { - count := 0 - t.activeExposes.Range(func(_, val any) bool { - if expose := val.(*trackedExpose); expose.peerID == peerID { - count++ - } - return true - }) - return count -} - -// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer. -func (t *exposeTracker) MaxExposesPerPeer() int { - return maxExposesPerPeer -} - -// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose. -// Returns false if the expose is not tracked or is being reaped. -func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool { - key := exposeKey(peerID, domain) - val, ok := t.activeExposes.Load(key) - if !ok { - return false - } - - expose := val.(*trackedExpose) - expose.mu.Lock() - if expose.expiring { - expose.mu.Unlock() - return false - } - expose.lastRenewed = time.Now() - expose.mu.Unlock() - - return true -} - -// StopTrackedExpose removes an active expose session from tracking. -// Returns false if the expose was not tracked. -func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool { - key := exposeKey(peerID, domain) - _, ok := t.activeExposes.LoadAndDelete(key) - return ok -} - -// StartExposeReaper starts a background goroutine that reaps expired expose sessions. -func (t *exposeTracker) StartExposeReaper(ctx context.Context) { +// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB. +func (r *exposeReaper) StartExposeReaper(ctx context.Context) { go func() { + // start with a random delay + rn := rand.IntN(10) + time.Sleep(time.Duration(rn) * time.Second) + ticker := time.NewTicker(exposeReapInterval) defer ticker.Stop() @@ -123,41 +35,31 @@ func (t *exposeTracker) StartExposeReaper(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - t.reapExpiredExposes() + r.reapExpiredExposes(ctx) } } }() } -func (t *exposeTracker) reapExpiredExposes() { - t.activeExposes.Range(func(key, val any) bool { - expose := val.(*trackedExpose) - expose.mu.Lock() - expired := time.Since(expose.lastRenewed) > exposeTTL - if expired { - expose.expiring = true - } - expose.mu.Unlock() +func (r *exposeReaper) reapExpiredExposes(ctx context.Context) { + expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch) + if err != nil { + log.Errorf("failed to get expired ephemeral services: %v", err) + return + } - if !expired { - return true + for _, svc := range expired { + log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain) + + err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID) + if err == nil { + continue } - log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain) - - err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true) - - s, _ := status.FromError(err) - - switch { - case err == nil: - t.activeExposes.Delete(key) - case s.ErrorType == status.NotFound: - log.Debugf("service %s was already deleted", expose.domain) - default: - log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err) + if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound { + log.Debugf("service %s was already deleted by another instance", svc.Domain) + } else { + log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err) } - - return true - }) + } } diff --git a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index 154239fb1..bd9f4b93b 100644 --- a/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -10,184 +10,62 @@ import ( "github.com/stretchr/testify/require" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/server/store" ) -func TestExposeKey(t *testing.T) { - assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com")) - assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com")) - assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com")) -} - -func TestTrackExposeIfAllowed(t *testing.T) { - t.Run("first track succeeds", func(t *testing.T) { - tracker := &exposeTracker{} - alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - assert.False(t, alreadyTracked, "first track should not be duplicate") - assert.True(t, ok, "first track should be allowed") - }) - - t.Run("duplicate track detected", func(t *testing.T) { - tracker := &exposeTracker{} - tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - - alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - assert.True(t, alreadyTracked, "second track should be duplicate") - assert.False(t, ok) - }) - - t.Run("rejects when at limit", func(t *testing.T) { - tracker := &exposeTracker{} - for i := range maxExposesPerPeer { - _, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1") - assert.True(t, ok, "track %d should be allowed", i) - } - - alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1") - assert.False(t, alreadyTracked) - assert.False(t, ok, "should reject when at limit") - }) - - t.Run("other peer unaffected by limit", func(t *testing.T) { - tracker := &exposeTracker{} - for i := range maxExposesPerPeer { - tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1") - } - - _, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1") - assert.True(t, ok, "other peer should still be within limit") - }) -} - -func TestUntrackExpose(t *testing.T) { - tracker := &exposeTracker{} - - tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - assert.Equal(t, 1, tracker.CountPeerExposes("peer1")) - - tracker.UntrackExpose("peer1", "a.com") - assert.Equal(t, 0, tracker.CountPeerExposes("peer1")) -} - -func TestCountPeerExposes(t *testing.T) { - tracker := &exposeTracker{} - - assert.Equal(t, 0, tracker.CountPeerExposes("peer1")) - - tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1") - tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1") - - assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes") - assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose") - assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes") -} - -func TestMaxExposesPerPeer(t *testing.T) { - tracker := &exposeTracker{} - assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer()) -} - -func TestRenewTrackedExpose(t *testing.T) { - tracker := &exposeTracker{} - - found := tracker.RenewTrackedExpose("peer1", "a.com") - assert.False(t, found, "should not find untracked expose") - - tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - - found = tracker.RenewTrackedExpose("peer1", "a.com") - assert.True(t, found, "should find tracked expose") -} - -func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) { - tracker := &exposeTracker{} - tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") - - // Simulate reaper marking the expose as expiring - key := exposeKey("peer1", "a.com") - val, _ := tracker.activeExposes.Load(key) - expose := val.(*trackedExpose) - expose.mu.Lock() - expose.expiring = true - expose.mu.Unlock() - - found := tracker.RenewTrackedExpose("peer1", "a.com") - assert.False(t, found, "should reject renewal when expiring") -} - func TestReapExpiredExposes(t *testing.T) { - mgr, _ := setupIntegrationTest(t) - tracker := mgr.exposeTracker - + mgr, testStore := setupIntegrationTest(t) ctx := context.Background() + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) require.NoError(t, err) - // Manually expire the tracked entry - key := exposeKey(testPeerID, resp.Domain) - val, _ := tracker.activeExposes.Load(key) - expose := val.(*trackedExpose) - expose.mu.Lock() - expose.lastRenewed = time.Now().Add(-2 * exposeTTL) - expose.mu.Unlock() + // Manually expire the service by backdating meta_last_renewed_at + expireEphemeralService(t, testStore, testAccountID, resp.Domain) - // Add an active (non-expired) tracking entry - tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{ - domain: "active.com", - accountID: testAccountID, - peerID: "peer1", - lastRenewed: time.Now(), + // Create a non-expired service + resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8081, + Protocol: "http", }) + require.NoError(t, err) - tracker.reapExpiredExposes() + mgr.exposeReaper.reapExpiredExposes(ctx) - _, exists := tracker.activeExposes.Load(key) - assert.False(t, exists, "expired expose should be removed") + // Expired service should be deleted + _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + require.Error(t, err, "expired service should be deleted") - _, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com")) - assert.True(t, exists, "active expose should remain") + // Non-expired service should remain + _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain) + require.NoError(t, err, "active service should remain") } -func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) { - mgr, _ := setupIntegrationTest(t) - tracker := mgr.exposeTracker - +func TestReapAlreadyDeletedService(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) ctx := context.Background() + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) require.NoError(t, err) - key := exposeKey(testPeerID, resp.Domain) - val, _ := tracker.activeExposes.Load(key) - expose := val.(*trackedExpose) + expireEphemeralService(t, testStore, testAccountID, resp.Domain) - // Expire it - expose.mu.Lock() - expose.lastRenewed = time.Now().Add(-2 * exposeTTL) - expose.mu.Unlock() + // Delete the service before reaping + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) - // Renew should succeed before reaping - assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs") - - // Re-expire and reap - expose.mu.Lock() - expose.lastRenewed = time.Now().Add(-2 * exposeTTL) - expose.mu.Unlock() - - tracker.reapExpiredExposes() - - // Entry is deleted, renew returns false - assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap") + // Reaping should handle the already-deleted service gracefully + mgr.exposeReaper.reapExpiredExposes(ctx) } -func TestConcurrentTrackAndCount(t *testing.T) { - mgr, _ := setupIntegrationTest(t) - tracker := mgr.exposeTracker +func TestConcurrentReapAndRenew(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) ctx := context.Background() for i := range 5 { @@ -198,59 +76,133 @@ func TestConcurrentTrackAndCount(t *testing.T) { require.NoError(t, err) } - // Manually expire all tracked entries - tracker.activeExposes.Range(func(_, val any) bool { - expose := val.(*trackedExpose) - expose.mu.Lock() - expose.lastRenewed = time.Now().Add(-2 * exposeTTL) - expose.mu.Unlock() - return true - }) - - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - tracker.reapExpiredExposes() - }() - go func() { - defer wg.Done() - tracker.CountPeerExposes(testPeerID) - }() - wg.Wait() - - assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped") -} - -func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) { - expose := &trackedExpose{ - lastRenewed: time.Now().Add(-1 * time.Hour), + // Expire all services + services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + for _, svc := range services { + if svc.Source == rpservice.SourceEphemeral { + expireEphemeralService(t, testStore, testAccountID, svc.Domain) + } } var wg sync.WaitGroup wg.Add(2) - go func() { defer wg.Done() - for range 100 { - expose.mu.Lock() - expose.lastRenewed = time.Now() - expose.mu.Unlock() - } + mgr.exposeReaper.reapExpiredExposes(ctx) }() - go func() { defer wg.Done() - for range 100 { - expose.mu.Lock() - _ = time.Since(expose.lastRenewed) - expose.mu.Unlock() - } + _, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) }() - wg.Wait() - expose.mu.Lock() - require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access") - expose.mu.Unlock() + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count, "all expired services should be reaped") +} + +func TestRenewEphemeralService(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ctx := context.Background() + + t.Run("renew succeeds for active service", func(t *testing.T) { + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8082, + Protocol: "http", + }) + require.NoError(t, err) + + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) + }) + + t.Run("renew fails for nonexistent domain", func(t *testing.T) { + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "no active expose session") + }) +} + +func TestCountAndExistsEphemeralServices(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ctx := context.Background() + + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count) + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8083, + Protocol: "http", + }) + require.NoError(t, err) + + count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(1), count) + + exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) + assert.True(t, exists, "service should exist") + + exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain") + require.NoError(t, err) + assert.False(t, exists, "non-existent service should not exist") +} + +func TestMaxExposesPerPeerEnforced(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + ctx := context.Background() + + for i := range maxExposesPerPeer { + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8090 + i, + Protocol: "http", + }) + require.NoError(t, err, "expose %d should succeed", i) + } + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 9999, + Protocol: "http", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "maximum number of active expose sessions") +} + +func TestReapSkipsRenewedService(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + ctx := context.Background() + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ + Port: 8086, + Protocol: "http", + }) + require.NoError(t, err) + + // Expire the service + expireEphemeralService(t, testStore, testAccountID, resp.Domain) + + // Renew it before the reaper runs + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) + + // Reaper should skip it because the re-check sees a fresh timestamp + mgr.exposeReaper.reapExpiredExposes(ctx) + + _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + require.NoError(t, err, "renewed service should survive reaping") +} + +// expireEphemeralService backdates meta_last_renewed_at to force expiration. +func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) { + t.Helper() + svc, err := s.GetServiceByDomain(context.Background(), accountID, domain) + require.NoError(t, err) + + expired := time.Now().Add(-2 * exposeTTL) + svc.Meta.LastRenewedAt = &expired + err = s.UpdateService(context.Background(), svc) + require.NoError(t, err) } diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 16a57abb6..b5e643799 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -37,7 +37,7 @@ type Manager struct { permissionsManager permissions.Manager proxyController proxy.Controller clusterDeriver ClusterDeriver - exposeTracker *exposeTracker + exposeReaper *exposeReaper } // NewManager creates a new service manager. @@ -49,13 +49,13 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa proxyController: proxyController, clusterDeriver: clusterDeriver, } - mgr.exposeTracker = &exposeTracker{manager: mgr} + mgr.exposeReaper = &exposeReaper{manager: mgr} return mgr } -// StartExposeReaper delegates to the expose tracker. +// StartExposeReaper starts the background goroutine that reaps expired ephemeral services. func (m *Manager) StartExposeReaper(ctx context.Context) { - m.exposeTracker.StartExposeReaper(ctx) + m.exposeReaper.StartExposeReaper(ctx) } func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { @@ -215,6 +215,52 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, servi }) } +// persistNewEphemeralService creates an ephemeral service inside a single transaction +// that also enforces the duplicate and per-peer limit checks atomically. +// The count and exists queries use FOR UPDATE locking to serialize concurrent creates +// for the same peer, preventing the per-peer limit from being bypassed. +func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error { + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + // Lock the peer row to serialize concurrent creates for the same peer. + // Without this, when no ephemeral rows exist yet, FOR UPDATE on the services + // table returns no rows and acquires no locks, allowing concurrent inserts + // to bypass the per-peer limit. + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil { + return fmt.Errorf("lock peer row: %w", err) + } + + exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain) + if err != nil { + return fmt.Errorf("check existing expose: %w", err) + } + if exists { + return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") + } + + count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID) + if err != nil { + return fmt.Errorf("count peer exposes: %w", err) + } + if count >= int64(maxExposesPerPeer) { + return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + } + + if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil { + return err + } + + if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil { + return err + } + + if err := transaction.CreateService(ctx, svc); err != nil { + return fmt.Errorf("create service: %w", err) + } + + return nil + }) +} + func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) if err != nil { @@ -412,10 +458,6 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI return err } - if s.Source == service.SourceEphemeral { - m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain) - } - m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta()) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) @@ -457,9 +499,6 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin oidcCfg := m.proxyController.GetOIDCValidationConfig() for _, svc := range services { - if svc.Source == service.SourceEphemeral { - m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain) - } m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta()) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster) } @@ -681,26 +720,13 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s return nil, err } - now := time.Now() - svc.Meta.LastRenewedAt = &now svc.SourcePeer = peerID - if err := m.persistNewService(ctx, accountID, svc); err != nil { - return nil, err - } + now := time.Now() + svc.Meta.LastRenewedAt = &now - alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID) - if alreadyTracked { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil { - log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err) - } - return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") - } - if !allowed { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil { - log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err) - } - return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil { + return nil, err } meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) @@ -748,26 +774,17 @@ func (m *Manager) buildRandomDomain(name string) (string, error) { return domain, nil } -// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session. -// Returns an error if the expose is not actively tracked. -func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error { - if !m.exposeTracker.RenewTrackedExpose(peerID, domain) { - return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) - } - return nil +// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service. +func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { + return m.store.RenewEphemeralService(ctx, accountID, peerID, domain) } -// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service. +// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB. func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) return err } - - if !m.exposeTracker.StopTrackedExpose(peerID, domain) { - log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain) - } - return nil } @@ -848,6 +865,57 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv return nil } +// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking +// that it is still expired under a row lock. This prevents deleting a service +// that was renewed between the batch query and this delete, and ensures only one +// management instance processes the deletion +func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error { + var svc *service.Service + deleted := false + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + if err != nil { + return err + } + + if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID { + return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner") + } + + if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL { + return nil + } + + if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil { + return fmt.Errorf("delete service: %w", err) + } + deleted = true + + return nil + }) + if err != nil { + return err + } + + if !deleted { + return nil + } + + peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err) + peer = nil + } + + meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) + m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any { if peer == nil { return meta diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 99409e235..196eead22 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -720,7 +720,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { domains: []string{"test.netbird.io"}, }, } - mgr.exposeTracker = &exposeTracker{manager: mgr} + mgr.exposeReaper = &exposeReaper{manager: mgr} return mgr, testStore } @@ -1017,36 +1017,38 @@ func TestStopServiceFromPeer(t *testing.T) { }) } -func TestDeleteService_UntracksEphemeralExpose(t *testing.T) { +func TestDeleteService_DeletesEphemeralExpose(t *testing.T) { ctx := context.Background() - mgr, _ := setupIntegrationTest(t) + mgr, testStore := setupIntegrationTest(t) resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) require.NoError(t, err) - assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create") - // Look up the service by domain to get its store ID - svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain) + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(1), count, "one ephemeral service should exist after create") + + svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) require.NoError(t, err) - // Delete via the API path (user-initiated) err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) require.NoError(t, err) - assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete") + count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete") - // A new expose should succeed (not blocked by stale tracking) _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 9090, Protocol: "http", }) - assert.NoError(t, err, "new expose should succeed after API delete cleared tracking") + assert.NoError(t, err, "new expose should succeed after API delete") } -func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) { +func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) { ctx := context.Background() mgr, _ := setupIntegrationTest(t) @@ -1058,12 +1060,16 @@ func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) { require.NoError(t, err) } - assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked") + count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(3), count, "all ephemeral services should exist") - err := mgr.DeleteAllServices(ctx, testAccountID, testUserID) + err = mgr.DeleteAllServices(ctx, testAccountID, testUserID) require.NoError(t, err) - assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices") + count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID) + require.NoError(t, err) + assert.Equal(t, int64(0), count, "all ephemeral services should be deleted after DeleteAllServices") } func TestRenewServiceFromPeer(t *testing.T) { diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index 46ae185d6..ee4a91e1f 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -133,8 +133,8 @@ type Service struct { Meta Meta `gorm:"embedded;embeddedPrefix:meta_"` SessionPrivateKey string `gorm:"column:session_private_key"` SessionPublicKey string `gorm:"column:session_public_key"` - Source string `gorm:"default:'permanent'"` - SourcePeer string + Source string `gorm:"default:'permanent';index:idx_service_source_peer"` + SourcePeer string `gorm:"index:idx_service_source_peer"` } func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 41c53980b..8f147d915 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5040,6 +5040,99 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS return serviceList, nil } +// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service. +func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { + result := s.db.Model(&rpservice.Service{}). + Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Update("meta_last_renewed_at", time.Now()) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error) + return status.Errorf(status.Internal, "renew ephemeral service") + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) + } + return nil +} + +// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL. +// Only the fields needed for reaping are selected. The limit parameter caps the batch size to +// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to +// skip malformed legacy data. +func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) { + cutoff := time.Now().Add(-ttl) + var services []*rpservice.Service + result := s.db. + Select("id", "account_id", "source_peer", "domain"). + Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff). + Limit(limit). + Find(&services) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error) + return nil, status.Errorf(status.Internal, "get expired ephemeral services") + } + return services, nil +} + +// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer. +// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations. +// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to +// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*). +func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) { + if lockStrength == LockingStrengthNone { + var count int64 + result := s.db.Model(&rpservice.Service{}). + Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral). + Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error) + return 0, status.Errorf(status.Internal, "count ephemeral services") + } + return count, nil + } + + var ids []string + result := s.db.Model(&rpservice.Service{}). + Clauses(clause.Locking{Strength: string(lockStrength)}). + Select("id"). + Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral). + Pluck("id", &ids) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error) + return 0, status.Errorf(status.Internal, "count ephemeral services") + } + return int64(len(ids)), nil +} + +// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain. +// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations. +func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { + if lockStrength == LockingStrengthNone { + var count int64 + result := s.db.Model(&rpservice.Service{}). + Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error) + return false, status.Errorf(status.Internal, "check ephemeral service existence") + } + return count > 0, nil + } + + var id string + result := s.db.Model(&rpservice.Service{}). + Clauses(clause.Locking{Strength: string(lockStrength)}). + Select("id"). + Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral). + Limit(1). + Pluck("id", &id) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error) + return false, status.Errorf(status.Internal, "check ephemeral service existence") + } + return id != "", nil +} + func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 941aca08a..5123cde72 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -261,6 +261,11 @@ type Store interface { GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) + RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error + GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) + CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) + EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) + GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 9e11f85fb..414872fbb 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -208,6 +208,21 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain) } +// CountEphemeralServicesByPeer mocks base method. +func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer. +func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID) +} + // CreateAccessLog mocks base method. func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { m.ctrl.T.Helper() @@ -686,6 +701,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID) } +// EphemeralServiceExists mocks base method. +func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EphemeralServiceExists indicates an expected call of EphemeralServiceExists. +func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain) +} + // ExecuteInTransaction mocks base method. func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error { m.ctrl.T.Helper() @@ -1362,6 +1392,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID) } +// GetExpiredEphemeralServices mocks base method. +func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit) + ret0, _ := ret[0].([]*service.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices. +func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit) +} + // GetGroupByID mocks base method. func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) { m.ctrl.T.Helper() @@ -2401,6 +2446,20 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID) } +// RenewEphemeralService mocks base method. +func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenewEphemeralService indicates an expected call of RenewEphemeralService. +func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain) +} + // RevokeProxyAccessToken mocks base method. func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error { m.ctrl.T.Helper()