[management] Add explicit target delete on service removal (#5420)

This commit is contained in:
Pascal Fischer
2026-03-02 18:25:44 +01:00
committed by GitHub
parent bbe5ae2145
commit 82da606886
7 changed files with 1932 additions and 36 deletions

View File

@@ -4,12 +4,12 @@ import (
"context"
"fmt"
"math/rand/v2"
"slices"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
@@ -410,12 +410,15 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
var service *reverseproxy.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}

View File

@@ -13,11 +13,14 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -1112,3 +1115,67 @@ func TestGetGroupIDsFromNames(t *testing.T) {
assert.Contains(t, err.Error(), "no group names provided")
})
}
func TestDeleteService_DeletesTargets(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
userID := "test-user"
sqlStore, err := store.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
require.NoError(t, err)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
mockGRPC := &nbgrpc.ProxyServiceServer{}
mgr := &managerImpl{
store: sqlStore,
permissionsManager: mockPerms,
accountManager: mockAcct,
proxyGRPCServer: mockGRPC,
}
service := &reverseproxy.Service{
ID: "service-1",
AccountID: accountID,
Domain: "test.example.com",
ProxyCluster: "cluster1",
Enabled: true,
Targets: []*reverseproxy.Target{
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-1"},
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-2"},
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-3"},
},
}
err = sqlStore.CreateService(ctx, service)
require.NoError(t, err)
retrievedService, err := sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().
UpdateAccountPeers(ctx, accountID)
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
require.NoError(t, err)
_, err = sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.Error(t, err)
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err := sqlStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, accountID, service.ID)
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}

View File

@@ -1,5 +1,7 @@
package account
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"net"
@@ -61,11 +63,11 @@ type Manager interface {
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)

File diff suppressed because it is too large Load Diff

View File

@@ -4895,6 +4895,46 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil
}
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete target from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
}
return nil
}
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete targets from store")
}
return nil
}
// GetTargetsByServiceID retrieves all targets for a given service
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) {
var targets []*reverseproxy.Target
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get targets from store")
}
return targets, nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {

View File

@@ -272,6 +272,9 @@ type Store interface {
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error)
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
// GetCustomDomainsCounts returns the total and validated custom domain counts.
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)

View File

@@ -559,6 +559,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
}
// DeleteServiceTargets mocks base method.
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
}
// DeleteSetupKey mocks base method.
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
m.ctrl.T.Helper()
@@ -573,6 +587,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
}
// DeleteTarget mocks base method.
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteTarget indicates an expected call of DeleteTarget.
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
}
// DeleteTokenID2UserIDIndex mocks base method.
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
m.ctrl.T.Helper()
@@ -1109,21 +1137,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetAccountSettings mocks base method.
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
m.ctrl.T.Helper()
@@ -1288,6 +1301,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
}
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetDNSRecordByID mocks base method.
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
m.ctrl.T.Helper()
@@ -1872,22 +1901,6 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTargetByTargetID", reflect.TypeOf((*MockStore)(nil).GetServiceTargetByTargetID), ctx, lockStrength, accountID, targetID)
}
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
@@ -1903,6 +1916,21 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetSetupKeyByID mocks base method.
func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) {
m.ctrl.T.Helper()
@@ -1962,6 +1990,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
}
// GetTargetsByServiceID mocks base method.
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*reverseproxy.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].([]*reverseproxy.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
}
// GetTokenIDByHashedToken mocks base method.
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
m.ctrl.T.Helper()