mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[management] Add explicit target delete on service removal (#5420)
This commit is contained in:
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"slices"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
@@ -410,12 +410,15 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
|
||||
var service *reverseproxy.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,11 +13,14 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -1112,3 +1115,67 @@ func TestGetGroupIDsFromNames(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "no group names provided")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
userID := "test-user"
|
||||
|
||||
sqlStore, err := store.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockPerms := permissions.NewMockManager(ctrl)
|
||||
mockAcct := account.NewMockManager(ctrl)
|
||||
mockGRPC := &nbgrpc.ProxyServiceServer{}
|
||||
|
||||
mgr := &managerImpl{
|
||||
store: sqlStore,
|
||||
permissionsManager: mockPerms,
|
||||
accountManager: mockAcct,
|
||||
proxyGRPCServer: mockGRPC,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-1",
|
||||
AccountID: accountID,
|
||||
Domain: "test.example.com",
|
||||
ProxyCluster: "cluster1",
|
||||
Enabled: true,
|
||||
Targets: []*reverseproxy.Target{
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-1"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-2"},
|
||||
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-3"},
|
||||
},
|
||||
}
|
||||
|
||||
err = sqlStore.CreateService(ctx, service)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrievedService, err := sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
|
||||
|
||||
mockPerms.EXPECT().
|
||||
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
|
||||
Return(true, nil)
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sqlStore.GetServiceByID(ctx, store.LockingStrengthNone, accountID, service.ID)
|
||||
require.Error(t, err)
|
||||
s, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.NotFound, s.Type())
|
||||
|
||||
targets, err := sqlStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, accountID, service.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package account
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
@@ -61,11 +63,11 @@ type Manager interface {
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
|
||||
1738
management/server/account/manager_mock.go
Normal file
1738
management/server/account/manager_mock.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4895,6 +4895,46 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete target from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete targets from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID retrieves all targets for a given service
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
var targets []*reverseproxy.Target
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get targets from store")
|
||||
}
|
||||
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
@@ -272,6 +272,9 @@ type Store interface {
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error)
|
||||
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
|
||||
@@ -559,6 +559,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// DeleteServiceTargets mocks base method.
|
||||
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
|
||||
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// DeleteSetupKey mocks base method.
|
||||
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -573,6 +587,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
|
||||
}
|
||||
|
||||
// DeleteTarget mocks base method.
|
||||
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteTarget indicates an expected call of DeleteTarget.
|
||||
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex mocks base method.
|
||||
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1109,21 +1137,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSettings mocks base method.
|
||||
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1288,6 +1301,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts mocks base method.
|
||||
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
|
||||
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
|
||||
}
|
||||
|
||||
// GetDNSRecordByID mocks base method.
|
||||
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1872,22 +1901,6 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTargetByTargetID", reflect.TypeOf((*MockStore)(nil).GetServiceTargetByTargetID), ctx, lockStrength, accountID, targetID)
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts mocks base method.
|
||||
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
|
||||
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
|
||||
}
|
||||
|
||||
// GetServices mocks base method.
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1903,6 +1916,21 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID mocks base method.
|
||||
func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1962,6 +1990,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID mocks base method.
|
||||
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
|
||||
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
|
||||
}
|
||||
|
||||
// GetTokenIDByHashedToken mocks base method.
|
||||
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Reference in New Issue
Block a user