diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 65177bf5d..2251f5084 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -14,6 +14,8 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" @@ -636,18 +638,12 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco for _, target := range targets { switch target.TargetType { case service.TargetTypePeer: - if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil { + return err } case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: - if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { - if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { - return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) - } - return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) + if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil { + return err } default: return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId) @@ -656,6 +652,39 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } +func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) + } + return nil +} + +func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) + } + return fmt.Errorf("look up resource target %q: %w", target.TargetId, err) + } + return validateResourceTargetType(target, resource) +} + +// validateResourceTargetType checks that target_type matches the actual network resource type. +func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error { + expected := resourcetypes.NetworkResourceType(target.TargetType) + if resource.Type != expected { + return status.Errorf(status.InvalidArgument, + "target %q has target_type %q but resource is of type %q", + target.TargetId, target.TargetType, resource.Type, + ) + } + return nil +} + func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) if err != nil { diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index d23c91017..0c34f81a2 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/mock_server" + resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -1214,3 +1215,60 @@ func TestValidateProtocolChange(t *testing.T) { }) } } + +func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + tests := []struct { + name string + targetType rpservice.TargetType + resourceType resourcetypes.NetworkResourceType + wantErr bool + }{ + {"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false}, + {"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false}, + {"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false}, + {"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true}, + {"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true}, + {"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true}, + {"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStore.EXPECT(). + GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1"). + Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil) + + targets := []*rpservice.Target{ + {TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"}, + } + err := validateTargetReferences(ctx, mockStore, accountID, targets) + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "target_type") + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateTargetReferences_PeerValid(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + mockStore := store.NewMockStore(ctrl) + accountID := "test-account" + + mockStore.EXPECT(). + GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1"). + Return(&nbpeer.Peer{}, nil) + + targets := []*rpservice.Target{ + {TargetId: "peer-1", TargetType: rpservice.TargetTypePeer}, + } + require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets)) +} diff --git a/management/internals/modules/reverseproxy/service/service.go b/management/internals/modules/reverseproxy/service/service.go index 6c7c80806..c00d49421 100644 --- a/management/internals/modules/reverseproxy/service/service.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -790,7 +790,7 @@ func (s *Service) validateL4Target(target *Target) error { return errors.New("target_id is required for L4 services") } switch target.TargetType { - case TargetTypePeer, TargetTypeHost: + case TargetTypePeer, TargetTypeHost, TargetTypeDomain: // OK case TargetTypeSubnet: if target.Host == "" { diff --git a/management/internals/modules/reverseproxy/service/service_test.go b/management/internals/modules/reverseproxy/service/service_test.go index 9daf729fe..3fe07b1d0 100644 --- a/management/internals/modules/reverseproxy/service/service_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -847,6 +847,32 @@ func TestValidate_TLSSubnetValid(t *testing.T) { require.NoError(t, rp.Validate()) } +func TestValidate_L4DomainTargetValid(t *testing.T) { + modes := []struct { + mode string + port uint16 + proto string + }{ + {"tcp", 5432, "tcp"}, + {"tls", 443, "tcp"}, + {"udp", 5432, "udp"}, + } + for _, m := range modes { + t.Run(m.mode, func(t *testing.T) { + rp := &Service{ + Name: m.mode + "-domain", + Mode: m.mode, + Domain: "cluster.test", + ListenPort: m.port, + Targets: []*Target{ + {TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true}, + }, + } + require.NoError(t, rp.Validate()) + }) + } +} + func TestValidate_HTTPProxyProtocolRejected(t *testing.T) { rp := validProxy() rp.Targets[0].ProxyProtocol = true