Compare commits

...

32 Commits

Author SHA1 Message Date
pascal
a4b55af99c add additional integration tests 2026-04-20 17:24:04 +02:00
pascal
470307079b update account isolation 2026-04-20 16:33:24 +02:00
pascal
0b04c0d03b leave todo comments 2026-04-20 15:03:31 +02:00
pascal
bed8d89d9f account for peers permission on the groups endpoint 2026-04-20 14:00:57 +02:00
pascal
b65a8bcb9c update sql test data files and auth wrappers 2026-04-20 13:44:38 +02:00
pascal
a53c38a6ed add sql test files 2026-04-17 17:09:38 +02:00
pascal
8a41117403 fix version test 2026-04-17 16:51:08 +02:00
pascal
e0063731f2 update tests 2026-04-17 16:44:57 +02:00
pascal
a572d8819f fix user handler test 2026-04-17 14:59:19 +02:00
pascal
d139a4fc42 fix client cmd test 2026-04-17 14:34:39 +02:00
pascal
dbe533327f fix client server test 2026-04-17 13:23:26 +02:00
pascal
4835909a04 fix client test 2026-04-17 13:16:33 +02:00
pascal
fb6447e714 fix engine test 2026-04-16 18:44:40 +02:00
pascal
16d6cc1265 update management integrations 2026-04-16 18:36:55 +02:00
pascal
53e47da7bd fix tests 2026-04-16 18:24:47 +02:00
pascal
ce522ea69b fix mock 2026-04-16 18:16:44 +02:00
pascal
5be80e976f fix mock 2026-04-16 17:55:14 +02:00
pascal
dd9539a9bf fix gomod 2026-04-16 17:48:11 +02:00
pascal
8530f9b8fc Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	go.mod
#	go.sum
#	management/server/account_test.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-04-16 17:47:30 +02:00
pascal
9406de9610 add exception handler and add exception for users 2026-04-16 17:42:59 +02:00
pascal
9e385eb540 add optional auth error handlers 2026-04-16 16:24:55 +02:00
pascal
e46ea895c1 Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	management/server/account_test.go
#	management/server/http/handlers/networks/routers_handler.go
2026-04-16 14:52:46 +02:00
pascal
31d901c4b0 update role permissions for admins 2026-04-08 12:15:01 +02:00
pascal
20e6dff507 Merge branch 'main' into refactor/permissions-manager 2026-04-07 17:35:39 +02:00
pascal
f5c8a6fe1a update integrations 2026-03-27 16:46:08 +01:00
pascal
beee14b9bf fix merge conflicts 2026-03-27 14:47:06 +01:00
pascal
3013c98ab5 Merge branch 'main' into refactor/permissions-manager
# Conflicts:
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-03-27 14:37:29 +01:00
pascal
3741eb46dd Merge remote-tracking branch 'origin/main' into refactor/permissions-manager
# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/internals/server/modules.go
#	management/server/http/testing/testing_tools/channel/channel.go
2026-03-17 12:38:08 +01:00
pascal
d85ee0b5a2 remove old permissions management 2026-03-06 18:19:33 +01:00
pascal
da4a0eb68a Merge branch 'main' into refactor/permissions-manager 2026-03-06 17:06:18 +01:00
pascal
b0ce0048b4 Merge branch 'refs/heads/main' into refactor/permissions-manager
# Conflicts:
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/server/http/handler.go
2026-03-06 11:37:53 +01:00
pascal
32730af33f use api wrapper for permissions management 2026-02-24 00:39:54 +01:00
145 changed files with 4143 additions and 3647 deletions

View File

@@ -13,6 +13,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -29,7 +31,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -97,7 +98,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
peersmanager := peers.NewManager(store)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersmanager)

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/management-integrations/integrations"
@@ -57,7 +58,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -1632,7 +1632,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)

View File

@@ -15,6 +15,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -38,7 +40,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -305,7 +306,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
peersManager := peers.NewManager(store)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)

2
go.mod
View File

@@ -71,7 +71,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0

4
go.sum
View File

@@ -453,8 +453,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34 h1:g74mB64wnjCagzE1spKgPfTI/ont1SdSL3uX5bOecgM=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416163311-004852ffaf34/go.mod h1:lCOq5d1i19AQjEEW2d7aNK0Nn0KC0MKyfMz/PLwVBFg=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -16,9 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -38,17 +35,15 @@ type Manager interface {
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
accountManager account.Manager
networkMapController network_map.Controller
}
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
func NewManager(store store.Store) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
store: store,
}
}
@@ -65,28 +60,10 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}

View File

@@ -4,20 +4,29 @@ package permissions
import (
"context"
"net/http"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/permissions/roles"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// AuthErrorHandler is called when an auth error occurs during permission validation.
// If it returns true, the error is considered handled and the default error response is skipped.
type AuthErrorHandler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool
type Manager interface {
WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
@@ -36,6 +45,51 @@ func NewManager(store store.Store) Manager {
}
}
// WithPermission wraps an HTTP handler with permission checking logic.
// An optional AuthErrorHandler can be provided to intercept auth errors before the default response is written.
func (m *managerImpl) WithPermission(
module modules.Module,
operation operations.Operation,
handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth),
authErrHandler ...AuthErrorHandler,
) http.HandlerFunc {
var onAuthErr AuthErrorHandler
if len(authErrHandler) > 0 {
onAuthErr = authErrHandler[0]
}
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
util.WriteError(r.Context(), err, w)
return
}
allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation)
if err != nil {
if onAuthErr != nil && onAuthErr(w, r, &userAuth, err) {
return
}
log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err)
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
if !allowed {
permErr := status.NewPermissionDeniedError()
if onAuthErr != nil && onAuthErr(w, r, &userAuth, permErr) {
return
}
log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module)
util.WriteError(r.Context(), permErr, w)
return
}
handlerFunc(w, r, &userAuth)
}
}
func (m *managerImpl) ValidateUserPermissions(
ctx context.Context,
accountID string,
@@ -127,3 +181,17 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
// no-op
}
// WrapHandler wraps a handler that expects UserAuth with context extraction.
// Unlike WithPermission, it does not perform any permission checks.
func WrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
util.WriteError(r.Context(), err, w)
return
}
h(w, r, &userAuth)
}
}

View File

@@ -5,15 +5,18 @@
package permissions
import (
context "context"
reflect "reflect"
"context"
"net/http"
"reflect"
gomock "github.com/golang/mock/gomock"
account "github.com/netbirdio/netbird/management/server/account"
modules "github.com/netbirdio/netbird/management/server/permissions/modules"
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
types "github.com/netbirdio/netbird/management/server/types"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
)
// MockManager is a mock of Manager interface.
@@ -108,3 +111,22 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation)
}
// WithPermission mocks base method.
func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth), authErrHandler ...AuthErrorHandler) http.HandlerFunc {
m.ctrl.T.Helper()
varargs := []interface{}{module, operation, handlerFunc}
for _, a := range authErrHandler {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "WithPermission", varargs...)
ret0, _ := ret[0].(http.HandlerFunc)
return ret0
}
// WithPermission indicates an expected call of WithPermission.
func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}, authErrHandler ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{module, operation, handlerFunc}, authErrHandler...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), varargs...)
}

View File

@@ -1,8 +1,8 @@
package roles
import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -18,7 +18,7 @@ var Admin = RolePermissions{
modules.Accounts: {
operations.Read: true,
operations.Create: false,
operations.Update: false,
operations.Update: true,
operations.Delete: false,
},
},

View File

@@ -1,7 +1,7 @@
package roles
import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
)

View File

@@ -1,8 +1,8 @@
package roles
import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
)

View File

@@ -1,7 +1,7 @@
package roles
import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
)

View File

@@ -1,8 +1,8 @@
package roles
import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
)

View File

@@ -1,7 +1,7 @@
package roles
import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
)

View File

@@ -5,8 +5,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -15,21 +18,15 @@ type handler struct {
manager accesslogs.Manager
}
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
}
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r)

View File

@@ -9,25 +9,19 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
geo geolocation.Geolocation
cleanupCancel context.CancelFunc
store store.Store
geo geolocation.Geolocation
cleanupCancel context.CancelFunc
}
func NewManager(store store.Store, permissionsManager permissions.Manager, geo geolocation.Geolocation) accesslogs.Manager {
func NewManager(store store.Store, geo geolocation.Geolocation) accesslogs.Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
geo: geo,
store: store,
geo: geo,
}
}
@@ -63,14 +57,6 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}
if !ok {
return nil, 0, status.NewPermissionDeniedError()
}
if err := m.resolveUserFilters(ctx, accountID, filter); err != nil {
log.WithContext(ctx).Warnf("failed to resolve user filters: %v", err)
}

View File

@@ -6,8 +6,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -17,15 +20,15 @@ type handler struct {
manager Manager
}
func RegisterEndpoints(router *mux.Router, manager Manager) {
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Create, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") // TODO: this should be a POST
}
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
@@ -56,13 +59,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
return resp
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -77,13 +74,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, ret)
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiReverseProxiesDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -99,13 +90,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
@@ -120,13 +105,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)

View File

@@ -11,11 +11,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
type store interface {
@@ -37,32 +33,22 @@ type proxyManager interface {
}
type Manager struct {
store store
validator domain.Validator
proxyManager proxyManager
permissionsManager permissions.Manager
accountManager account.Manager
store store
validator domain.Validator
proxyManager proxyManager
accountManager account.Manager
}
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager {
return Manager{
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver},
permissionsManager: permissionsManager,
accountManager: accountManager,
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{Resolver: net.DefaultResolver},
accountManager: accountManager,
}
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list custom domains: %w", err)
@@ -118,14 +104,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
}
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
@@ -159,14 +137,6 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
d, err := m.store.GetCustomDomain(ctx, accountID, domainID)
if err != nil {
return fmt.Errorf("get domain from store: %w", err)
@@ -183,21 +153,6 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
}
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
return
}
if !ok {
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,
}).WithError(err).Error("validate domain")
}
log.WithFields(log.Fields{
"accountID": accountID,
"domainID": domainID,

View File

@@ -23,7 +23,7 @@ type Proxy struct {
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
Capabilities Capabilities `gorm:"embedded"`
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -6,12 +6,14 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -30,25 +32,19 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domainmanager.RegisterEndpoints(domainRouter, domainManager)
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -63,13 +59,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiServices)
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -77,12 +67,13 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
}
service := new(rpservice.Service)
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil {
if err := service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -96,13 +87,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -118,13 +103,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -139,12 +118,13 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
service := new(rpservice.Service)
service.ID = serviceID
var err error
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
if err = service.Validate(); err != nil {
if err := service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -158,13 +138,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -179,13 +153,7 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)

View File

@@ -15,7 +15,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
@@ -86,18 +85,17 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore),
proxyController: mockCtrl,
capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
store: testStore,
accountManager: accountMgr,
proxyController: mockCtrl,
capabilities: mockCaps,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}

View File

@@ -21,9 +21,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -82,24 +79,22 @@ type CapabilityProvider interface {
}
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
proxyController proxy.Controller
capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
store store.Store
accountManager account.Manager
proxyController proxy.Controller
capabilities CapabilityProvider
clusterDeriver ClusterDeriver
exposeReaper *exposeReaper
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
proxyController: proxyController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
store: store,
accountManager: accountManager,
proxyController: proxyController,
capabilities: capabilities,
clusterDeriver: clusterDeriver,
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr
@@ -112,26 +107,10 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -185,14 +164,6 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
@@ -206,14 +177,6 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
@@ -224,7 +187,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, s)
err := m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
@@ -491,14 +454,6 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
@@ -785,16 +740,8 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -825,16 +772,8 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
@@ -1119,7 +1058,7 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
}
groupIDs := make([]string, 0, len(groupNames))
for _, groupName := range groupNames {
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID, activity.SystemInitiator)
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
}

View File

@@ -23,9 +23,6 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
@@ -700,12 +697,10 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
},
}
@@ -718,10 +713,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
require.NoError(t, err)
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permsMgr,
proxyController: proxyController,
store: testStore,
accountManager: accountMgr,
proxyController: proxyController,
clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"},
},
@@ -1130,7 +1124,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
@@ -1141,10 +1134,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
mgr := &Manager{
store: sqlStore,
permissionsManager: mockPerms,
accountManager: mockAcct,
proxyController: proxyController,
store: sqlStore,
accountManager: mockAcct,
proxyController: proxyController,
}
service := &rpservice.Service{
@@ -1167,9 +1159,6 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
require.Len(t, retrievedService.Targets, 3, "Service should have 3 targets before deletion")
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().

View File

@@ -6,8 +6,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -17,25 +20,19 @@ type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil {
if err := zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err = zone.Validate(); err != nil {
if err := zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -7,62 +7,34 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
dnsDomain string
store store.Store
accountManager account.Manager
dnsDomain string
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
func NewManager(store store.Store, accountManager account.Manager, dnsDomain string) zones.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
store: store,
accountManager: accountManager,
dnsDomain: dnsDomain,
}
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var err error
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err
}
@@ -102,14 +74,6 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err)
@@ -150,14 +114,6 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -13,9 +13,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
@@ -29,7 +26,7 @@ const (
testDNSDomain = "netbird.selfhosted"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
@@ -49,23 +46,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccoun
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
manager := NewManager(testStore, mockAccountManager, testDNSDomain).(*managerImpl)
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
return manager, testStore, mockAccountManager, ctrl, cleanup
}
func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -77,10 +68,6 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Len(t, result, 2)
@@ -88,43 +75,13 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
assert.Equal(t, zone2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -132,10 +89,6 @@ func TestManagerImpl_GetZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID)
@@ -143,29 +96,13 @@ func TestManagerImpl_GetZone(t *testing.T) {
assert.Equal(t, zone.Domain, result.Domain)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -177,10 +114,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -199,31 +132,8 @@ func TestManagerImpl_CreateZone(t *testing.T) {
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -233,17 +143,13 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{"invalid-group"},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -259,10 +165,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -273,7 +175,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
})
t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -291,10 +193,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -305,7 +203,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
})
t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -317,10 +215,6 @@ func TestManagerImpl_CreateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -335,7 +229,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -352,10 +246,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -375,7 +265,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
})
t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -392,10 +282,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -405,31 +291,8 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -439,10 +302,6 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
@@ -453,7 +312,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background()
t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -469,10 +328,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++
@@ -493,7 +348,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
})
t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -501,10 +356,6 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -522,31 +373,11 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)
})

View File

@@ -6,8 +6,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -17,25 +20,19 @@ type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
record := new(records.Record)
record.FromAPIRequest(&req)
if err = record.Validate(); err != nil {
if err := record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
record.FromAPIRequest(&req)
record.ID = recordID
if err = record.Validate(); err != nil {
if err := record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
return
}
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -9,64 +9,36 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
store store.Store
accountManager account.Manager
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
func NewManager(store store.Store, accountManager account.Manager) records.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
store: store,
accountManager: accountManager,
}
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
@@ -101,18 +73,11 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
var record *records.Record
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
@@ -160,18 +125,11 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record
var zone *zones.Zone
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)

View File

@@ -12,12 +12,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -27,7 +23,7 @@ const (
testGroupID = "test-group-id"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
@@ -51,22 +47,17 @@ func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_serv
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
manager := NewManager(testStore, mockAccountManager).(*managerImpl)
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
return manager, testStore, zone, mockAccountManager, ctrl, cleanup
}
func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -78,10 +69,6 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Len(t, result, 2)
@@ -89,43 +76,13 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
assert.Equal(t, record2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -133,10 +90,6 @@ func TestManagerImpl_GetRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.Equal(t, record.ID, result.ID)
@@ -146,29 +99,13 @@ func TestManagerImpl_GetRecord(t *testing.T) {
assert.Equal(t, record.TTL, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -179,10 +116,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -202,7 +135,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -213,10 +146,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -231,7 +160,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -242,10 +171,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
@@ -259,32 +184,8 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -295,10 +196,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -306,7 +203,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -321,10 +218,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -332,7 +225,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
})
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -347,10 +240,6 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -362,7 +251,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -378,10 +267,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Changed TTL
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -400,7 +285,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
})
t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -416,10 +301,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600, // Only TTL changed
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
}
@@ -430,33 +311,8 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
assert.Equal(t, 600, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -468,17 +324,13 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -498,10 +350,6 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
@@ -513,7 +361,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, testStore, zone, mockAccountManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
@@ -521,10 +369,6 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
@@ -542,31 +386,11 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
manager, _, zone, _, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)
})

View File

@@ -223,7 +223,7 @@ func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
return Create(s, func() accesslogs.Manager {
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.GeoLocationManager())
accessLogManager.StartPeriodicCleanup(
context.Background(),
s.Config.ReverseProxy.AccessLogRetentionDays,

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
@@ -27,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
)
@@ -82,13 +82,13 @@ func (s *BaseServer) SettingsManager() settings.Manager {
idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled
}
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig)
return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, idpConfig)
})
}
func (s *BaseServer) PeersManager() peers.Manager {
return Create(s, func() peers.Manager {
manager := peers.NewManager(s.Store(), s.PermissionsManager())
manager := peers.NewManager(s.Store())
s.AfterInit(func(s *BaseServer) {
manager.SetNetworkMapController(s.NetworkMapController())
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
@@ -161,43 +161,43 @@ func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
return groups.NewManager(s.Store(), s.AccountManager())
})
}
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
})
}
func (s *BaseServer) RoutesManager() routers.Manager {
return Create(s, func() routers.Manager {
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
return routers.NewManager(s.Store(), s.AccountManager())
})
}
func (s *BaseServer) NetworksManager() networks.Manager {
return Create(s, func() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}
func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.DNSDomain())
})
}
func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
return recordsManager.NewManager(s.Store(), s.AccountManager())
})
}
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
})
}
@@ -213,7 +213,7 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager(), s.AccountManager())
m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager())
return &m
})
}

View File

@@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -39,9 +40,6 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -282,22 +280,14 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account.
// Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
var oldSettings *types.Settings
var updateAccountPeers bool
var groupChangesAffectPeers bool
var reloadReverseProxy bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool
var err error
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
@@ -725,15 +715,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
if err != nil {
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
@@ -976,6 +957,10 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
return nil, err
}
if user.AccountID != accountID {
return nil, nil
}
key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
@@ -1287,41 +1272,16 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountMeta returns the account metadata associated with this account ID.
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
}
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
log.Errorf("failed to get account onboarding for account %s: %v", accountID, err)
@@ -1338,15 +1298,6 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
}
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
@@ -1405,9 +1356,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return "", "", err
}
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
// User account association is already validated above by GetUserByUserID
if !user.IsServiceUser && userAuth.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
@@ -1849,13 +1799,6 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
}
@@ -2197,14 +2140,6 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
updateNetworkMap, err := am.updatePeerIPInTransaction(ctx, accountID, userID, peerID, newIP)
if err != nil {
return fmt.Errorf("update peer IP transaction: %w", err)

View File

@@ -60,7 +60,7 @@ type Manager interface {
GetUserByID(ctx context.Context, id string) (*types.User, error)
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
}
// GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
}
// GetIdentityProvider mocks base method.
@@ -946,18 +946,18 @@ func (mr *MockManagerMockRecorder) GetPeerNetwork(ctx, peerID interface{}) *gomo
}
// GetPeers mocks base method.
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*peer.Peer, error) {
func (m *MockManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string, all bool) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter)
ret := m.ctrl.Call(m, "GetPeers", ctx, accountID, userID, nameFilter, ipFilter, all)
ret0, _ := ret[0].([]*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeers indicates an expected call of GetPeers.
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetPeers(ctx, accountID, userID, nameFilter, ipFilter, all interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeers", reflect.TypeOf((*MockManager)(nil).GetPeers), ctx, accountID, userID, nameFilter, ipFilter, all)
}
// GetPolicy mocks base method.

View File

@@ -22,9 +22,11 @@ import (
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -49,7 +51,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -3124,7 +3125,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
@@ -3141,7 +3142,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, nil, err
@@ -3152,7 +3153,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, proxyManager, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil))
return manager, updateManager, nil
}

View File

@@ -8,8 +8,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
@@ -22,14 +20,6 @@ const (
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
}
@@ -39,18 +29,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}

View File

@@ -14,11 +14,11 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -28,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -79,16 +78,6 @@ func TestGetDNSSettings(t *testing.T) {
if len(dnsSettings.DisabledManagementGroups) != 1 {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
}
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
s, ok := status.FromError(err)
if !ok && s.Type() != status.PermissionDenied {
t.Errorf("returned error should be Permission Denied, got err: %s", err)
}
}
func TestSaveDNSSettings(t *testing.T) {
@@ -223,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
peersManager := peers.NewManager(store)
ctx := context.Background()
@@ -234,7 +223,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
}

View File

@@ -9,11 +9,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
func isEnabled() bool {
@@ -23,14 +20,6 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil {
return nil, err

View File

@@ -12,8 +12,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
@@ -32,13 +30,24 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
// Permission checks are now handled by the HTTP middleware via WithPermission wrapper
// This method is called from authenticated/authorized handlers, so we just validate
// that the user exists and is part of the account
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return err
}
if !allowed {
return status.NewPermissionDeniedError()
if user == nil {
return status.NewUserNotFoundError(userID)
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
if user.IsBlocked() {
return status.NewUserBlockedError()
}
return nil
@@ -61,27 +70,17 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}
// CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
@@ -125,19 +124,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
// UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
@@ -196,33 +187,24 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
if err := transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
@@ -243,6 +225,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
@@ -264,21 +247,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
@@ -311,6 +287,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
}
var err error
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
@@ -416,14 +393,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*types.Group

View File

@@ -26,7 +26,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -764,11 +763,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
groupsManager := groups.NewManager(manager.Store, manager)
resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, manager)
networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
ID: "network_test",

View File

@@ -6,9 +6,6 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -25,31 +22,21 @@ type Manager interface {
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
accountManager account.Manager
store store.Store
accountManager account.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager account.Manager) Manager {
func NewManager(store store.Store, accountManager account.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
accountManager: accountManager,
store: store,
accountManager: accountManager,
}
}
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return nil, err
}
if !ok {
return nil, err
}
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err)
@@ -73,14 +60,6 @@ func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID str
}
func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return err
}
if !ok {
return err
}
event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource)
if err != nil {
return fmt.Errorf("error adding resource to group: %w", err)

View File

@@ -13,6 +13,7 @@ import (
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
"github.com/netbirdio/netbird/management/server/types"
@@ -34,10 +35,8 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/server/auth"
@@ -154,25 +153,25 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, router)
accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddEndpoints(accountManager, router, permissionsManager)
users.AddInvitesEndpoints(accountManager, router, permissionsManager)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
setup_keys.AddEndpoints(accountManager, router, permissionsManager)
policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
groups.AddEndpoints(accountManager, router)
routes.AddEndpoints(accountManager, router)
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
groups.AddEndpoints(accountManager, router, permissionsManager)
routes.AddEndpoints(accountManager, router, permissionsManager)
dns.AddEndpoints(accountManager, router, permissionsManager)
events.AddEndpoints(accountManager, router, permissionsManager)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
idp.AddEndpoints(accountManager, router, permissionsManager)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}

View File

@@ -12,10 +12,13 @@ import (
goversion "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -40,11 +43,11 @@ type handler struct {
settingsManager settings.Manager
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
accountsHandler := newHandler(accountManager, settingsManager)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
@@ -99,7 +102,7 @@ func (h *handler) validateNetworkRange(ctx context.Context, accountID, userID st
}
func (h *handler) validateCapacity(ctx context.Context, accountID, userID string, prefix netip.Prefix) error {
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "")
peers, err := h.accountManager.GetPeers(ctx, accountID, userID, "", "", true)
if err != nil {
return status.Errorf(status.Internal, "get peer count: %v", err)
}
@@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 {
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -233,24 +228,15 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
}
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
_, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountID := mux.Vars(r)["accountId"]
if accountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
return
}
var req api.PutApiAccountsAccountIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -267,7 +253,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return
}
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -282,19 +268,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
}
}
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -306,21 +292,14 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
}
// deleteAccount is a HTTP DELETE handler to delete an account
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountID := mux.Vars(r)["accountId"]
if accountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "account ID mismatch"), w)
return
}
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return
}
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
err := h.accountManager.DeleteAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
@@ -290,8 +291,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
router.HandleFunc("/api/accounts", permissions.WrapHandler(handler.getAllAccounts)).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", permissions.WrapHandler(handler.updateAccount)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -5,11 +5,13 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -19,15 +21,15 @@ type dnsSettingsHandler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
addDNSSettingEndpoint(accountManager, router)
addDNSNameserversEndpoint(accountManager, router)
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
addDNSSettingEndpoint(accountManager, router, permissionsManager)
addDNSNameserversEndpoint(accountManager, router, permissionsManager)
}
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
@@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
}
// getDNSSettings returns the DNS settings for the account
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
}
// updateDNSSettings handles update to DNS settings of an account
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PutApiDnsSettingsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -17,6 +17,7 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -115,8 +116,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.getDNSSettings)).Methods("GET")
router.HandleFunc("/api/dns/settings", permissions.WrapHandler(p.updateDNSSettings)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,11 +6,13 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +23,13 @@ type nameserversHandler struct {
accountManager account.Manager
}
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
nameserversHandler := newNameserversHandler(accountManager)
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
}
// newNameserversHandler returns a new instance of nameserversHandler handler
@@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
}
// getAllNameservers returns the list of nameserver groups for the account
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
}
// createNameserverGroup handles nameserver group creation request
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiDnsNameserversJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
}
// updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
}
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
}
// deleteNameserverGroup handles nameserver group deletion request
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
}
// getNameserverGroup handles a nameserver group Get request identified by ID
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -18,6 +18,7 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -201,10 +202,10 @@ func TestNameserversHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.getNameserverGroup)).Methods("GET")
router.HandleFunc("/api/dns/nameservers", permissions.WrapHandler(p.createNameserverGroup)).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", permissions.WrapHandler(p.updateNameserverGroup)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -5,11 +5,13 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -19,10 +21,10 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
eventsHandler := newHandler(accountManager)
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
@@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllEvents list of the given account
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -13,6 +13,7 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -196,7 +197,7 @@ func TestEvents_GetEvents(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
router.HandleFunc("/api/events/", permissions.WrapHandler(handler.getAllEvents)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,11 +7,13 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -19,46 +21,45 @@ import (
// handler is a handler that returns groups of the account
type handler struct {
accountManager account.Manager
accountManager account.Manager
permissionsManager permissions.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
groupsHandler := newHandler(accountManager)
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
groupsHandler := newHandler(accountManager, permissionsManager)
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new groups handler
func newHandler(accountManager account.Manager) *handler {
func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
return &handler{
accountManager: accountManager,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
return err == nil && allowed
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
// Check if filtering by name
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -71,13 +72,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
}
// Get all groups
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -92,15 +93,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
}
// updateGroup handles update to a group identified by a given ID
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
@@ -112,13 +105,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return
}
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -166,13 +159,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: existingGroup.IntegrationReference,
}
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -182,17 +175,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
}
// createGroup handles group creation request
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiGroupsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -226,13 +211,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
Issued: types.GroupIssuedAPI,
}
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -242,22 +227,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
}
// deleteGroup handles group deletion request
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
if err != nil {
wrappedErr, ok := err.(interface{ Unwrap() []error })
if ok && len(wrappedErr.Unwrap()) > 0 {
@@ -273,34 +250,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
}
// getGroup returns a group
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
}
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {

View File

@@ -13,10 +13,14 @@ import (
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -33,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(initGroups ...*types.Group) *handler {
func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
t.Helper()
ctrl := gomock.NewController(t)
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
Return(true, nil).
AnyTimes()
return &handler{
permissionsManager: permissionsManagerMock,
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
if !strings.HasPrefix(group.ID, "id-") {
@@ -71,14 +85,14 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}
return nil, status.Errorf(status.NotFound, "unknown group name")
},
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
GetPeersFunc: func(ctx context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
@@ -128,7 +142,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
p := initGroupTestData(group)
p := initGroupTestData(t, group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -141,7 +155,7 @@ func TestGetGroup(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.getGroup)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -254,7 +268,7 @@ func TestWriteGroup(t *testing.T) {
},
}
p := initGroupTestData()
p := initGroupTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -267,8 +281,8 @@ func TestWriteGroup(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
router.HandleFunc("/api/groups", permissions.WrapHandler(p.createGroup)).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.updateGroup)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -332,7 +346,7 @@ func TestGetAllGroups(t *testing.T) {
},
}
p := initGroupTestData()
p := initGroupTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -345,7 +359,7 @@ func TestGetAllGroups(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
router.HandleFunc("/api/groups", permissions.WrapHandler(p.getAllGroups)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -414,7 +428,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
p := initGroupTestData()
p := initGroupTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -426,7 +440,7 @@ func TestDeleteGroup(t *testing.T) {
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
router.HandleFunc("/api/groups/{groupId}", permissions.WrapHandler(p.deleteGroup)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -20,13 +23,13 @@ type handler struct {
}
// AddEndpoints registers identity provider endpoints
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := newHandler(accountManager)
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
}
func newHandler(accountManager account.Manager) *handler {
@@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllIdentityProviders returns all identity providers for the account
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
}
// getIdentityProvider returns a specific identity provider
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
return
}
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
}
// createIdentityProvider creates a new identity provider
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
idp := fromAPIRequest(&req)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
}
// updateIdentityProvider updates an existing identity provider
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
idp := fromAPIRequest(&req)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
}
// deleteIdentityProvider deletes an identity provider
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request)
return
}
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -120,7 +121,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.getAllIdentityProviders)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -180,7 +181,7 @@ func TestGetIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.getIdentityProvider)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -242,7 +243,7 @@ func TestCreateIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
router.HandleFunc("/api/identity-providers", permissions.WrapHandler(h.createIdentityProvider)).Methods("POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -328,7 +329,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.updateIdentityProvider)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -388,7 +389,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
router.HandleFunc("/api/identity-providers/{idpId}", permissions.WrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,7 +7,11 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
@@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)

View File

@@ -10,12 +10,17 @@ import (
"net/mail"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -295,8 +300,15 @@ func TestSetup_ManagerError(t *testing.T) {
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
ctrl := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl)
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(w, r, &auth.UserAuth{})
}
}).AnyTimes()
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
AddVersionEndpoint(manager, router, permissionsManager)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
@@ -323,8 +335,15 @@ func TestGetVersionInfo_Error(t *testing.T) {
return nil, errors.New("failed to fetch versions")
},
}
ctrl := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl)
permissionsManager.EXPECT().WithPermission(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(module modules.Module, operation operations.Operation, handler func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth), authErrHandler ...permissions.AuthErrorHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
handler(w, r, &auth.UserAuth{})
}
}).AnyTimes()
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
AddVersionEndpoint(manager, router, permissionsManager)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()

View File

@@ -9,8 +9,10 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -18,6 +20,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -33,16 +36,16 @@ type handler struct {
groupsManager groups.Manager
}
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, router)
addResourceEndpoints(resourceManager, groupsManager, router)
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, permissionsManager, router)
addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
}
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
@@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
}
}
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
}
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
network := &types.Network{}
network.FromAPIRequest(&req)
network.AccountID = accountID
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
network.AccountID = userAuth.AccountId
network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
}
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
return
}
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
}
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
}
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
network.FromAPIRequest(&req)
network.ID = networkID
network.AccountID = accountID
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
network.AccountID = userAuth.AccountId
network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
}
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
return
}
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -6,10 +6,13 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -19,14 +22,14 @@ type resourceHandler struct {
groupsManager groups.Manager
}
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
}
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
@@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
}
}
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
resource.FromAPIRequest(&req)
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
resource.ID = mux.Vars(r)["resourceId"]
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -17,14 +20,14 @@ type routersHandler struct {
routersManager routers.Manager
}
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
}
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
@@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
}
}
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -103,7 +82,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
router.FromAPIRequest(&req)
router.NetworkID = networkID
router.AccountID = accountID
router.AccountID = userAuth.AccountId
router.Enabled = true
if err := router.Validate(); err != nil {
@@ -111,7 +90,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
return
}
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -120,18 +99,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -140,17 +111,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -161,14 +124,14 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
router.NetworkID = mux.Vars(r)["networkId"]
router.ID = mux.Vars(r)["routerId"]
router.AccountID = accountID
router.AccountID = userAuth.AccountId
if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -177,17 +140,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -1,7 +1,6 @@
package peers
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -12,15 +11,15 @@ import (
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -35,14 +34,15 @@ type Handler struct {
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
}
// NewHandler creates a new peers Handler
@@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
}
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
@@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w)
if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
for _, job := range jobs {
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
respBody = append(respBody, resp)
}
util.WriteJSONObject(ctx, w, respBody)
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
// GetPeer handles GET request for a single peer
func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
if peer.ProxyMeta.Embedded {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
// UpdatePeer handles PUT request to update a peer
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
}
update := &nbpeer.Peer{
ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
}
@@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip)
if err != nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return
}
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
util.WriteError(ctx, err, w)
if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
util.WriteError(r.Context(), err, w)
return
}
}
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
@@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, util.EmptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
// DeletePeer handles DELETE request to delete a peer
func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
return
}
switch r.Method {
case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w)
err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
util.WriteError(r.Context(), err, w)
return
case http.MethodGet:
h.getPeer(r.Context(), accountID, peerID, userID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// GetAllPeers returns a list of all peers associated with a provided account
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nameFilter := r.URL.Query().Get("name")
ipFilter := r.URL.Query().Get("ip")
accountID, userID := userAuth.AccountId, userAuth.UserId
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
@@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
}
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -372,25 +327,22 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userID)
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !allowed && !userAuth.IsChild {
// Check if user is an admin/service user through their role
isAdmin := user.Role == types.UserRoleAdmin || user.Role == types.UserRoleOwner
if !isAdmin && !user.IsServiceUser && !userAuth.IsChild {
if account.Settings.RegularUsersViewBlocked {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
@@ -408,7 +360,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
}
}
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -422,13 +374,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -437,7 +383,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
}
var req api.PeerTemporaryAccessRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return

View File

@@ -19,11 +19,11 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -174,7 +174,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
return nil, fmt.Errorf("user not found")
}
},
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
GetPeersFunc: func(_ context.Context, accountID, userID, nameFilter, ipFilter string, _ bool) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
@@ -307,9 +307,9 @@ func TestGetPeers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.HandleFunc("/api/peers/", permissions.WrapHandler(p.GetAllPeers)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.GetPeer)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -498,7 +498,7 @@ func TestGetAccessiblePeers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}/accessible-peers", permissions.WrapHandler(p.GetAccessiblePeers)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -582,7 +582,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
rr := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.HandleFunc("/peers/{peerId}", permissions.WrapHandler(p.UpdatePeer)).Methods("PUT")
router.ServeHTTP(rr, req)

View File

@@ -14,12 +14,12 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -121,7 +121,7 @@ func TestGetCitiesByCountry(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
router.HandleFunc("/api/locations/countries/{country}/cities", permissions.WrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -214,7 +214,7 @@ func TestGetAllCountries(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
router.HandleFunc("/api/locations/countries", permissions.WrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,12 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
@@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
}
// getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
}
// getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
@@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{
CountryName: country.CountryName,

View File

@@ -7,10 +7,13 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
policiesHandler := newHandler(accountManager)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new policies handler
@@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllPolicies list for the account
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
}
// updatePolicy handles update to a policy identified by a given ID
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
_, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, accountID, userID, policyID, false)
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
}
// createPolicy handles policy creation request
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
h.savePolicy(w, r, accountID, userID, "", true)
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
}
// savePolicy handles policy creation and update
@@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
}
// deletePolicy handles policy deletion request
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
return
}
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
}
// getPolicy handles a group Get request identified by ID
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -13,6 +13,7 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -111,7 +112,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.getPolicy)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -275,8 +276,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
router.HandleFunc("/api/policies", permissions.WrapHandler(p.createPolicy)).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", permissions.WrapHandler(p.updatePolicy)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,10 +6,13 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type postureChecksHandler struct {
geolocationManager geolocation.Geolocation
}
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
}
// newPostureChecksHandler creates a new PostureChecks handler
@@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
}
// getAllPostureChecks list for the account
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
}
// updatePostureCheck handles update to a posture check identified by a given ID
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
return
}
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
_, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
}
// createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
p.savePostureChecks(w, r, accountID, userID, "", true)
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
}
// getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
return
}
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
}
// deletePostureCheck handles posture check deletion request
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
return
}
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -183,7 +184,7 @@ func TestGetPostureCheck(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(p.getPostureCheck)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -841,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
router.HandleFunc("/api/posture-checks", permissions.WrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", permissions.WrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,9 +8,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -26,13 +29,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
routesHandler := newHandler(accountManager)
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
}
// newHandler returns a new instance of routes handler
@@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllRoutes returns the list of routes for the account
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
}
// createRoute handles route creation request
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiRoutesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
skipAutoApply = false
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
}
// updateRoute handles update to a route identified by a given ID
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
@@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
_, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.AccessControlGroups = *req.AccessControlGroups
}
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
}
// deleteRoute handles route deletion request
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
}
// getRoute handles a route Get request identified by ID
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/util"
@@ -501,10 +502,10 @@ func TestRoutesHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.getRoute)).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.deleteRoute)).Methods("DELETE")
router.HandleFunc("/api/routes", permissions.WrapHandler(p.createRoute)).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", permissions.WrapHandler(p.updateRoute)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,9 +8,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
keysHandler := newHandler(accountManager)
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new setup key handler
@@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler {
}
// createSetupKey is a POST requests that creates a new SetupKey
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
allowExtraDNSLabels = *req.AllowExtraDnsLabels
}
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels)
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
}
// getSetupKey is a GET request to get a SetupKey by ID
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
}
// updateSetupKey is a PUT request to update server.SetupKey
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
}
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
newKey.Revoked = req.Revoked
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
}
// getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
}
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -171,11 +172,11 @@ func TestSetupKeysHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", permissions.WrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", permissions.WrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -9,10 +9,13 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -55,14 +58,14 @@ type invitesHandler struct {
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.listInvites(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.createInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.regenerateInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.deleteInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
})

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -19,12 +22,12 @@ type patHandler struct {
accountManager account.Manager
}
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
}
// newPATsHandler creates a new patHandler HTTP handler
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
}
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
}
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
}
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
}
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
@@ -181,10 +182,10 @@ func TestTokenHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.getAllTokens)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.getToken)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", permissions.WrapHandler(p.createToken)).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", permissions.WrapHandler(p.deleteToken)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,14 +8,16 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
// handler is a handler that returns users of the account
@@ -23,18 +25,18 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers, userHandler.getOwnUser)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router, permissionsManager)
}
// newHandler creates a new UsersHandler HTTP handler
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
}
// updateUser is a PUT requests to update User data
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -71,6 +66,11 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
if existingUser.AccountID != userAuth.AccountId {
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "user not found"), w)
return
}
req := &api.PutApiUsersUserIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -89,7 +89,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
@@ -102,23 +102,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
}
// deleteUser is a DELETE request to delete a user
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -126,7 +119,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -136,21 +129,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
}
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -171,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -183,25 +169,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
}
// getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -215,7 +194,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, userID))
users = append(users, toUserResponse(d, userAuth.UserId))
continue
}
@@ -226,7 +205,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, userID))
users = append(users, toUserResponse(d, userAuth.UserId))
}
}
@@ -235,19 +214,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
// inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period.
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -255,7 +227,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -264,19 +236,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -356,7 +322,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
}
// approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -369,11 +335,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -385,7 +346,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
}
// rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -398,12 +359,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -421,7 +377,7 @@ type passwordChangeRequest struct {
// changePassword is a PUT request to change user's password.
// Only available when embedded IDP is enabled.
// Users can only change their own password.
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -434,19 +390,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req passwordChangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -454,3 +404,19 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getOwnUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth, err error) bool {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.PermissionDenied {
return false
}
user, userErr := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if userErr != nil {
util.WriteError(r.Context(), userErr, w)
return true
}
util.WriteJSONObject(r.Context(), w, []*api.User{toUserResponse(user.UserInfo, userAuth.UserId)})
return true
}

View File

@@ -15,10 +15,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
roles2 "github.com/netbirdio/netbird/management/internals/modules/permissions/roles"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
@@ -151,7 +152,7 @@ func initUsersTestData() *handler {
NonDeletable: false,
Issued: "api",
},
Permissions: mergeRolePermissions(roles.Owner),
Permissions: mergeRolePermissions(roles2.Owner),
}, nil
case "regular-user":
return &users.UserInfoWithPermissions{
@@ -165,7 +166,7 @@ func initUsersTestData() *handler {
NonDeletable: false,
Issued: "api",
},
Permissions: mergeRolePermissions(roles.User),
Permissions: mergeRolePermissions(roles2.User),
}, nil
case "admin-user":
@@ -181,7 +182,7 @@ func initUsersTestData() *handler {
LastLogin: time.Time{},
Issued: "api",
},
Permissions: mergeRolePermissions(roles.Admin),
Permissions: mergeRolePermissions(roles2.Admin),
}, nil
case "restricted-user":
return &users.UserInfoWithPermissions{
@@ -196,7 +197,7 @@ func initUsersTestData() *handler {
LastLogin: time.Time{},
Issued: "api",
},
Permissions: mergeRolePermissions(roles.User),
Permissions: mergeRolePermissions(roles2.User),
Restricted: true,
}, nil
}
@@ -232,7 +233,11 @@ func TestGetUsers(t *testing.T) {
AccountId: existingAccountID,
})
userHandler.getAllUsers(recorder, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getAllUsers(recorder, req, userAuth)
res := recorder.Result()
defer res.Body.Close()
@@ -343,7 +348,7 @@ func TestUpdateUser(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
router.HandleFunc("/api/users/{userId}", permissions.WrapHandler(userHandler.updateUser)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -439,7 +444,11 @@ func TestCreateUser(t *testing.T) {
AccountId: existingAccountID,
})
userHandler.createUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.createUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -490,7 +499,11 @@ func TestInviteUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.inviteUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -549,7 +562,11 @@ func TestDeleteUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.deleteUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -608,7 +625,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Owner)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Owner)),
},
},
},
@@ -627,7 +644,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
},
},
},
@@ -646,7 +663,7 @@ func TestCurrentUser(t *testing.T) {
Issued: ptr("api"),
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.Admin)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.Admin)),
},
},
},
@@ -666,7 +683,7 @@ func TestCurrentUser(t *testing.T) {
LastLogin: ptr(time.Time{}),
Permissions: &api.UserPermissions{
IsRestricted: true,
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles.User)),
Modules: stringifyPermissionsKeys(mergeRolePermissions(roles2.User)),
},
},
},
@@ -682,7 +699,11 @@ func TestCurrentUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.getCurrentUser(rr, req)
userAuth := &auth.UserAuth{
UserId: tc.requestAuth.UserId,
AccountId: existingAccountID,
}
userHandler.getCurrentUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -702,8 +723,8 @@ func ptr[T any, PT *T](x T) PT {
return &x
}
func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
permissions := roles.Permissions{}
func mergeRolePermissions(role roles2.RolePermissions) roles2.Permissions {
permissions := roles2.Permissions{}
for k := range modules.All {
if rolePermissions, ok := role.Permissions[k]; ok {
@@ -716,7 +737,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
return permissions
}
func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[string]bool {
func stringifyPermissionsKeys(permissions roles2.Permissions) map[string]map[string]bool {
modules := make(map[string]map[string]bool)
for module, operations := range permissions {
modules[string(module)] = make(map[string]bool)
@@ -779,7 +800,7 @@ func TestApproveUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
router.HandleFunc("/users/{userId}/approve", permissions.WrapHandler(handler.approveUser)).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
@@ -837,7 +858,7 @@ func TestRejectUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
router.HandleFunc("/users/{userId}/reject", permissions.WrapHandler(handler.rejectUser)).Methods("DELETE")
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
@@ -928,7 +949,7 @@ func TestChangePasswordEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/password", handler.changePassword).Methods("PUT")
router.HandleFunc("/users/{userId}/password", permissions.WrapHandler(handler.changePassword)).Methods("PUT")
reqPath := "/users/" + tc.targetUserID + "/password"
req, err := http.NewRequest("PUT", reqPath, bytes.NewBufferString(tc.requestBody))
@@ -967,7 +988,7 @@ func TestChangePasswordEndpoint_WrongMethod(t *testing.T) {
req = nbcontext.SetUserAuthInRequest(req, userAuth)
rr := httptest.NewRecorder()
handler.changePassword(rr, req)
handler.changePassword(rr, req, &userAuth)
assert.Equal(t, http.StatusMethodNotAllowed, rr.Code)
}

View File

@@ -233,6 +233,71 @@ func Test_Accounts_Update(t *testing.T) {
}
}
func Test_Accounts_Update_CrossAccountAttack(t *testing.T) {
t.Run("Other user attempts to update testAccount via URL", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.AccountRequest{
Settings: api.AccountSettings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: 86400,
},
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account update must be rejected")
})
}
func Test_Accounts_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, false},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, false},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
})
}
}
func Test_Accounts_Delete_CrossAccountAttack(t *testing.T) {
t.Run("Other user attempts to delete testAccount via URL", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// OtherUserId belongs to otherAccountId, but we target testAccountId in URL
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/accounts/"+testing_tools.TestAccountId, testing_tools.OtherUserId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account delete must be rejected")
})
}
func stringPointer(s string) *string {
return &s
}

View File

@@ -0,0 +1,445 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Records_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all records", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones/testZoneId/records", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.DNSRecord{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "sub.example.com", got[0].Name)
assert.Equal(t, api.DNSRecordTypeA, got[0].Type)
assert.Equal(t, "1.2.3.4", got[0].Content)
assert.Equal(t, 300, got[0].Ttl)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Records_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
expectedStatus int
expectRecord bool
}{
{
name: "Get existing record",
zoneId: "testZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusOK,
expectRecord: true,
},
{
name: "Get non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
expectedStatus: http.StatusNotFound,
expectRecord: false,
},
{
name: "Get record from non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusNotFound,
expectRecord: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectRecord {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testRecordId", got.Id)
assert.Equal(t, "sub.example.com", got.Name)
assert.Equal(t, api.DNSRecordTypeA, got.Type)
assert.Equal(t, "1.2.3.4", got.Content)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Records_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
requestBody *api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, record *api.DNSRecord)
}{
{
name: "Create A record",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "new.example.com",
Type: api.DNSRecordTypeA,
Content: "5.6.7.8",
Ttl: 600,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.NotEmpty(t, record.Id)
assert.Equal(t, "new.example.com", record.Name)
assert.Equal(t, api.DNSRecordTypeA, record.Type)
assert.Equal(t, "5.6.7.8", record.Content)
assert.Equal(t, 600, record.Ttl)
},
},
{
name: "Create CNAME record",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "alias.example.com",
Type: api.DNSRecordTypeCNAME,
Content: "target.example.com",
Ttl: 300,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.NotEmpty(t, record.Id)
assert.Equal(t, "alias.example.com", record.Name)
assert.Equal(t, api.DNSRecordTypeCNAME, record.Type)
assert.Equal(t, "target.example.com", record.Content)
},
},
{
name: "Create record with invalid content for A type",
zoneId: "testZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "bad.example.com",
Type: api.DNSRecordTypeA,
Content: "not-an-ip",
Ttl: 300,
},
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Create record in non-existing zone",
zoneId: "nonExistingZoneId",
requestBody: &api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "new.example.com",
Type: api.DNSRecordTypeA,
Content: "5.6.7.8",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
path := strings.Replace("/api/dns/zones/{zoneId}/records", "{zoneId}", tc.zoneId, 1)
req := testing_tools.BuildRequest(t, body, http.MethodPost, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created record directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbRecord := testing_tools.VerifyRecordInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbRecord.Name)
assert.Equal(t, got.Content, dbRecord.Content)
assert.Equal(t, got.Ttl, dbRecord.TTL)
}
})
}
}
}
func Test_Records_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
requestBody *api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, record *api.DNSRecord)
}{
{
name: "Update record content and TTL",
zoneId: "testZoneId",
recordId: "testRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, record *api.DNSRecord) {
t.Helper()
assert.Equal(t, "sub.example.com", record.Name)
assert.Equal(t, "10.20.30.40", record.Content)
assert.Equal(t, 600, record.Ttl)
},
},
{
name: "Update non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
{
name: "Update record in non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
requestBody: &api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "sub.example.com",
Type: api.DNSRecordTypeA,
Content: "10.20.30.40",
Ttl: 600,
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, body, http.MethodPut, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.DNSRecord{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated record directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbRecord := testing_tools.VerifyRecordInDB(t, db, tc.recordId)
assert.Equal(t, "10.20.30.40", dbRecord.Content)
assert.Equal(t, 600, dbRecord.TTL)
}
})
}
}
}
func Test_Records_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
recordId string
expectedStatus int
}{
{
name: "Delete existing record",
zoneId: "testZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing record",
zoneId: "testZoneId",
recordId: "nonExistingRecordId",
expectedStatus: http.StatusNotFound,
},
{
name: "Delete record from non-existing zone",
zoneId: "nonExistingZoneId",
recordId: "testRecordId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
path := strings.Replace("/api/dns/zones/{zoneId}/records/{recordId}", "{zoneId}", tc.zoneId, 1)
path = strings.Replace(path, "{recordId}", tc.recordId, 1)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, path, user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyRecordNotInDB(t, db, tc.recordId)
}
})
}
}
}

View File

@@ -0,0 +1,416 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Zones_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all zones", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/dns/zones", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Zone{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "Test Zone", got[0].Name)
assert.Equal(t, "example.com", got[0].Domain)
assert.Equal(t, true, got[0].Enabled)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Zones_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
expectedStatus int
expectZone bool
}{
{
name: "Get existing zone",
zoneId: "testZoneId",
expectedStatus: http.StatusOK,
expectZone: true,
},
{
name: "Get non-existing zone",
zoneId: "nonExistingZoneId",
expectedStatus: http.StatusNotFound,
expectZone: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectZone {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testZoneId", got.Id)
assert.Equal(t, "Test Zone", got.Name)
assert.Equal(t, "example.com", got.Domain)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_Zones_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
enabled := true
disabled := false
tt := []struct {
name string
requestBody *api.PostApiDnsZonesJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, zone *api.Zone)
}{
{
name: "Create zone with valid data",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "New Zone",
Domain: "newzone.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, "New Zone", zone.Name)
assert.Equal(t, "newzone.com", zone.Domain)
assert.Equal(t, true, zone.Enabled)
assert.Equal(t, false, zone.EnableSearchDomain)
assert.Equal(t, 1, len(zone.DistributionGroups))
},
},
{
name: "Create zone with search domain enabled",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "Search Zone",
Domain: "search.example.com",
Enabled: &enabled,
EnableSearchDomain: true,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, "Search Zone", zone.Name)
assert.Equal(t, true, zone.EnableSearchDomain)
},
},
{
name: "Create disabled zone",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "Disabled Zone",
Domain: "disabled.example.com",
Enabled: &disabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.NotEmpty(t, zone.Id)
assert.Equal(t, false, zone.Enabled)
},
},
{
name: "Create zone with empty distribution groups",
requestBody: &api.PostApiDnsZonesJSONRequestBody{
Name: "No Groups Zone",
Domain: "nogroups.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/dns/zones", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the created zone directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbZone := testing_tools.VerifyZoneInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbZone.Name)
assert.Equal(t, got.Domain, dbZone.Domain)
assert.Equal(t, got.Enabled, dbZone.Enabled)
assert.Equal(t, got.EnableSearchDomain, dbZone.EnableSearchDomain)
}
})
}
}
}
func Test_Zones_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
enabled := true
tt := []struct {
name string
zoneId string
requestBody *api.PutApiDnsZonesZoneIdJSONRequestBody
expectedStatus int
verifyResponse func(t *testing.T, zone *api.Zone)
}{
{
name: "Update zone name and domain",
zoneId: "testZoneId",
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "Updated Zone",
Domain: "updated.example.com",
Enabled: &enabled,
EnableSearchDomain: true,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, zone *api.Zone) {
t.Helper()
assert.Equal(t, "Updated Zone", zone.Name)
assert.Equal(t, "updated.example.com", zone.Domain)
assert.Equal(t, true, zone.EnableSearchDomain)
},
},
{
name: "Update non-existing zone",
zoneId: "nonExistingZoneId",
requestBody: &api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "Whatever",
Domain: "whatever.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{testing_tools.TestGroupId},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.Zone{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
// Verify the updated zone directly in the DB
db := testing_tools.GetDB(t, am.GetStore())
dbZone := testing_tools.VerifyZoneInDB(t, db, tc.zoneId)
assert.Equal(t, "Updated Zone", dbZone.Name)
assert.Equal(t, "updated.example.com", dbZone.Domain)
assert.Equal(t, true, dbZone.Enabled)
assert.Equal(t, true, dbZone.EnableSearchDomain)
}
})
}
}
}
func Test_Zones_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
zoneId string
expectedStatus int
}{
{
name: "Delete existing zone",
zoneId: "testZoneId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing zone",
zoneId: "nonExistingZoneId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/dns_zones.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/dns/zones/{zoneId}", "{zoneId}", tc.zoneId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
// Verify deletion in DB for successful deletes by privileged users
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyZoneNotInDB(t, db, tc.zoneId)
}
})
}
}
}

View File

@@ -78,6 +78,68 @@ func Test_Events_GetAll(t *testing.T) {
}
}
func Test_Events_GetAll_Audit(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all audit events", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, false)
// First, perform a mutation to generate an event (create a group as admin)
groupBody, err := json.Marshal(&api.GroupRequest{Name: "auditTestGroup"})
if err != nil {
t.Fatalf("Failed to marshal group request: %v", err)
}
createReq := testing_tools.BuildRequest(t, groupBody, http.MethodPost, "/api/groups", testing_tools.TestAdminId)
createRecorder := httptest.NewRecorder()
apiHandler.ServeHTTP(createRecorder, createReq)
assert.Equal(t, http.StatusOK, createRecorder.Code, "Failed to create group to generate event")
// Now query audit events
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/audit", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Event{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.GreaterOrEqual(t, len(got), 1, "Expected at least one event after creating a group")
// Verify the group creation event exists
found := false
for _, event := range got {
if event.ActivityCode == "group.add" {
found = true
assert.Equal(t, testing_tools.TestAdminId, event.InitiatorId)
assert.Equal(t, "Group created", event.Activity)
break
}
}
assert.True(t, found, "Expected to find a group.add event")
})
}
}
func Test_Events_GetAll_Empty(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/events.sql", nil, true)

View File

@@ -0,0 +1,118 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Geolocations_GetAllCountries(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all countries", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/locations/countries", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.Country{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Geolocations_GetCitiesByCountry(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get cities by country", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "US", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.City{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_Geolocations_GetCitiesByCountry_InvalidCode(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/locations/countries/{country}/cities", "{country}", "INVALID", 1), testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, true)
}

View File

@@ -216,7 +216,6 @@ func Test_Groups_Create(t *testing.T) {
}
tc.verifyResponse(t, got)
// Verify group exists in DB
db := testing_tools.GetDB(t, am.GetStore())
dbGroup := testing_tools.VerifyGroupInDB(t, db, got.Id)
assert.Equal(t, tc.requestBody.Name, dbGroup.Name)

View File

@@ -0,0 +1,295 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_IdentityProviders_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all identity providers", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/identity-providers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.IdentityProvider{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// The embedded IdP manager is not initialized in the test environment,
// so GetIdentityProviders returns an empty list.
assert.Equal(t, 0, len(got))
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_IdentityProviders_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
expectedStatus int
}{
{
name: "Get existing identity provider",
idpId: "testIdpId",
expectedStatus: http.StatusInternalServerError,
},
{
name: "Get non-existing identity provider",
idpId: "nonExistingIdpId",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
requestBody *api.PostApiIdentityProvidersJSONRequestBody
expectedStatus int
}{
{
name: "Create identity provider with valid data",
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "New IDP",
ClientId: "newClientId",
ClientSecret: "newClientSecret",
},
// Validation passes but the embedded IdP manager is not initialized,
// so the operation returns an internal server error.
expectedStatus: http.StatusInternalServerError,
},
{
name: "Create identity provider with invalid issuer",
requestBody: &api.PostApiIdentityProvidersJSONRequestBody{
Type: api.IdentityProviderTypeOidc,
Name: "Invalid IDP",
Issuer: "not-a-url",
ClientId: "clientId",
ClientSecret: "clientSecret",
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/identity-providers", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
requestBody *api.PutApiIdentityProvidersIdpIdJSONRequestBody
expectedStatus int
}{
{
name: "Update existing identity provider",
idpId: "testIdpId",
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "Updated IDP",
ClientId: "updatedClientId",
ClientSecret: "updatedClientSecret",
},
// Validation passes but the embedded IdP manager is not initialized,
// so the operation returns an internal server error.
expectedStatus: http.StatusInternalServerError,
},
{
name: "Update non-existing identity provider",
idpId: "nonExistingIdpId",
requestBody: &api.PutApiIdentityProvidersIdpIdJSONRequestBody{
Type: api.IdentityProviderTypeGoogle,
Name: "Updated IDP",
ClientId: "updatedClientId",
ClientSecret: "updatedClientSecret",
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_IdentityProviders_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
idpId string
expectedStatus int
}{
{
name: "Delete existing identity provider",
idpId: "testIdpId",
expectedStatus: http.StatusInternalServerError,
},
{
name: "Delete non-existing identity provider",
idpId: "nonExistingIdpId",
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/identity_providers.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/identity-providers/{idpId}", "{idpId}", tc.idpId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}

View File

@@ -0,0 +1,183 @@
//go:build integration
package integration
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Test_Instance_GetStatus tests the unauthenticated GET /api/instance endpoint.
// This endpoint bypasses auth middleware. With nil idpManager (no embedded IDP),
// SetupRequired should be false.
func Test_Instance_GetStatus(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// The /api/instance endpoint is unauthenticated (bypass path).
// We still pass a token via BuildRequest but the bypass middleware skips auth.
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for instance status endpoint, got %d: %s", recorder.Code, string(content))
got := &api.InstanceStatus{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
// With nil idpManager (no embedded IDP configured), setup is not required.
assert.Equal(t, false, got.SetupRequired, "Expected SetupRequired to be false when embedded IDP is not configured")
}
// Test_Instance_GetStatus_Unauthenticated verifies the endpoint works without any
// valid user token, since it is on the bypass path.
func Test_Instance_GetStatus_Unauthenticated(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// Use an invalid token to confirm the bypass middleware skips auth
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance", testing_tools.InvalidToken)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
assert.Equal(t, http.StatusOK, recorder.Code, "Expected 200 OK for unauthenticated instance status endpoint, got %d: %s", recorder.Code, string(content))
got := &api.InstanceStatus{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, false, got.SetupRequired)
}
// Test_Instance_GetVersionInfo tests the authenticated GET /api/instance/version endpoint.
func Test_Instance_GetVersionInfo(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get version info", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/instance/version", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.InstanceVersionInfo{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotEmpty(t, got.ManagementCurrentVersion, "Expected non-empty current version")
})
}
}
// Test_Instance_Setup tests the unauthenticated POST /api/setup endpoint.
// Since embedded IDP is not configured in the test environment, the setup
// endpoint should return an error (500 Internal Server Error) because the
// instance manager's CreateOwnerUser returns "embedded IDP is not enabled".
func Test_Instance_Setup(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.SetupRequest{
Email: "admin@test.com",
Password: "securepassword123",
Name: "Admin User",
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
// The /api/setup endpoint is unauthenticated (bypass path).
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Without embedded IDP, CreateOwnerUser returns a plain error (not a status error),
// which the handler maps to 500 Internal Server Error.
assert.Equal(t, http.StatusInternalServerError, recorder.Code,
"Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content))
}
// Test_Instance_Setup_Unauthenticated verifies the setup endpoint works (reaches
// the handler) even with an invalid token, since it is on the bypass path.
func Test_Instance_Setup_Unauthenticated(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
body, err := json.Marshal(&api.SetupRequest{
Email: "admin@test.com",
Password: "securepassword123",
Name: "Admin User",
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/setup", testing_tools.InvalidToken)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
// Even with an invalid token, the bypass middleware skips auth and the handler runs.
// Without embedded IDP, it returns 500.
assert.Equal(t, http.StatusInternalServerError, recorder.Code,
"Expected 500 when embedded IDP is not configured, got %d: %s", recorder.Code, string(content))
}

View File

@@ -0,0 +1,154 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// Note: The integration test infrastructure does not configure an embedded IDP,
// so actual invite operations will return PreconditionFailed (412) for authorized users.
// These tests verify that the permissions layer correctly denies regular users
// before the handler logic is reached.
func Test_Invites_List(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - List invites", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/invites", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Create invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, err := json.Marshal(&api.UserInviteCreateRequest{
Email: "newuser@test.com",
Name: "New User",
Role: "user",
AutoGroups: []string{},
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/users/invites", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/invites/{inviteId}", "{inviteId}", "someInviteId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}
func Test_Invites_Regenerate(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Regenerate invite", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/invites/{inviteId}/regenerate", "{inviteId}", "someInviteId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get PreconditionFailed (no embedded IDP configured)
// Unauthorized users get rejected by the permissions middleware
testing_tools.ReadResponse(t, recorder, http.StatusPreconditionFailed, user.expectResponse)
})
}
}

View File

@@ -0,0 +1,372 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_PostureChecks_GetAll(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all posture checks", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/posture-checks", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []*api.PostureCheck{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 1, len(got))
assert.Equal(t, "testPostureCheckId", got[0].Id)
assert.Equal(t, "NetBird Version Check", got[0].Name)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_PostureChecks_GetById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
postureCheckId string
expectedStatus int
expectCheck bool
}{
{
name: "Get existing posture check",
postureCheckId: "testPostureCheckId",
expectedStatus: http.StatusOK,
expectCheck: true,
},
{
name: "Get non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
expectedStatus: http.StatusNotFound,
expectCheck: false,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.expectCheck {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, "testPostureCheckId", got.Id)
assert.Equal(t, "NetBird Version Check", got.Name)
}
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
}
func Test_PostureChecks_Create(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
minVersion := "0.32.0"
tt := []struct {
name string
requestBody *api.PostureCheckUpdate
expectedStatus int
verifyResponse func(t *testing.T, check *api.PostureCheck)
}{
{
name: "Create posture check with NB version",
requestBody: &api.PostureCheckUpdate{
Name: "New Version Check",
Description: "check for new version",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: minVersion,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, check *api.PostureCheck) {
t.Helper()
assert.NotEmpty(t, check.Id)
assert.Equal(t, "New Version Check", check.Name)
assert.NotNil(t, check.Checks.NbVersionCheck)
assert.Equal(t, minVersion, check.Checks.NbVersionCheck.MinVersion)
},
},
{
name: "Create posture check with empty name",
requestBody: &api.PostureCheckUpdate{
Name: "",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "0.32.0",
},
},
},
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/posture-checks", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, got.Id)
assert.Equal(t, got.Name, dbCheck.Name)
}
})
}
}
}
func Test_PostureChecks_Update(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
minVersion := "0.33.0"
tt := []struct {
name string
postureCheckId string
requestBody *api.PostureCheckUpdate
expectedStatus int
verifyResponse func(t *testing.T, check *api.PostureCheck)
}{
{
name: "Update posture check name and version",
postureCheckId: "testPostureCheckId",
requestBody: &api.PostureCheckUpdate{
Name: "Updated Version Check",
Description: "updated description",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: minVersion,
},
},
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, check *api.PostureCheck) {
t.Helper()
assert.Equal(t, "testPostureCheckId", check.Id)
assert.Equal(t, "Updated Version Check", check.Name)
},
},
{
name: "Update non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
requestBody: &api.PostureCheckUpdate{
Name: "whatever",
Checks: &api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "0.33.0",
},
},
},
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
body, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPut, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if !expectResponse {
return
}
if tc.verifyResponse != nil {
got := &api.PostureCheck{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.verifyResponse(t, got)
db := testing_tools.GetDB(t, am.GetStore())
dbCheck := testing_tools.VerifyPostureCheckInDB(t, db, tc.postureCheckId)
assert.Equal(t, "Updated Version Check", dbCheck.Name)
}
})
}
}
}
func Test_PostureChecks_Delete(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
postureCheckId string
expectedStatus int
}{
{
name: "Delete existing posture check",
postureCheckId: "testPostureCheckId",
expectedStatus: http.StatusOK,
},
{
name: "Delete non-existing posture check",
postureCheckId: "nonExistingPostureCheckId",
expectedStatus: http.StatusNotFound,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/posture_checks.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/posture-checks/{postureCheckId}", "{postureCheckId}", tc.postureCheckId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if tc.expectedStatus == http.StatusOK && user.expectResponse {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyPostureCheckNotInDB(t, db, tc.postureCheckId)
}
})
}
}
}

View File

@@ -0,0 +1,306 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_ReverseProxy_GetClusters(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get clusters", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/clusters", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.ProxyCluster{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty clusters list when no proxy is connected")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_GetAllServices(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all services", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/services", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []*api.Service{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty services list with no services in DB")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_CreateService(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Create service", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
// Creating a service requires a valid domain entry. Without one, authorized
// users will receive a validation error. This test verifies that the
// permissions layer correctly rejects unauthorized users before handler
// logic runs, and that authorized users reach the handler (even if the
// handler returns an error due to missing domain).
body, err := json.Marshal(&api.ServiceRequest{
Name: "test-service",
Domain: "nonexistent.example.com",
Enabled: true,
})
if err != nil {
t.Fatalf("Failed to marshal request body: %v", err)
}
req := testing_tools.BuildRequest(t, body, http.MethodPost, "/api/reverse-proxies/services", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users reach the handler but get a validation error (no valid domain).
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusUnprocessableEntity, user.expectResponse)
})
}
}
func Test_ReverseProxy_GetServiceById(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get service by ID (non-existing)", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get NotFound for a non-existing service.
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse)
})
}
}
func Test_ReverseProxy_DeleteService(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Delete service (non-existing)", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/reverse-proxies/services/{serviceId}", "{serviceId}", "nonExistingServiceId", 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
// Authorized users get NotFound for a non-existing service.
// Unauthorized users are rejected by the permissions middleware.
testing_tools.ReadResponse(t, recorder, http.StatusNotFound, user.expectResponse)
})
}
}
func Test_ReverseProxy_GetAllDomains(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get all domains", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/reverse-proxies/domains", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := []api.ReverseProxyDomain{}
if err := json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got), "Expected empty domains list with no domains in DB")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}
func Test_ReverseProxy_GetAccessLogs(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get proxy access logs", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/accounts.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/events/proxy", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.ProxyAccessLogsResponse{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, 0, len(got.Data), "Expected empty access logs data")
assert.Equal(t, 0, got.TotalRecords, "Expected zero total records")
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}

View File

@@ -0,0 +1,124 @@
//go:build integration
package integration
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
)
func Test_Users_Approve(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Approve pending user",
targetUserId: "pendingUserId",
expectedStatus: http.StatusOK,
},
{
name: "Approve non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
{
name: "Approve already active user",
targetUserId: testing_tools.TestUserId,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodPost, strings.Replace("/api/users/{userId}/approve", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
})
}
}
}
func Test_Users_Reject(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, false},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, false},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
tt := []struct {
name string
targetUserId string
expectedStatus int
}{
{
name: "Reject pending user",
targetUserId: "pendingUserId",
expectedStatus: http.StatusOK,
},
{
name: "Reject non-existing user",
targetUserId: "nonExistingUserId",
expectedStatus: http.StatusNotFound,
},
{
name: "Reject non-pending user",
targetUserId: testing_tools.TestUserId,
expectedStatus: http.StatusUnprocessableEntity,
},
}
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_approve_reject.sql", nil, false)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, strings.Replace("/api/users/{userId}/reject", "{userId}", tc.targetUserId, 1), user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
_, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse)
if expectResponse && tc.expectedStatus == http.StatusOK {
db := testing_tools.GetDB(t, am.GetStore())
testing_tools.VerifyUserNotInDB(t, db, tc.targetUserId)
}
})
}
}
}

View File

@@ -0,0 +1,64 @@
//go:build integration
package integration
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_Users_GetCurrent(t *testing.T) {
users := []struct {
name string
userId string
expectResponse bool
}{
{"Regular user", testing_tools.TestUserId, true},
{"Admin user", testing_tools.TestAdminId, true},
{"Owner user", testing_tools.TestOwnerId, true},
{"Regular service user", testing_tools.TestServiceUserId, true},
{"Admin service user", testing_tools.TestServiceAdminId, true},
{"Blocked user", testing_tools.BlockedUserId, false},
{"Other user", testing_tools.OtherUserId, false},
{"Invalid token", testing_tools.InvalidToken, false},
}
for _, user := range users {
t.Run(user.name+" - Get current user", func(t *testing.T) {
apiHandler, _, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, http.MethodGet, "/api/users/current", user.userId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
content, expectResponse := testing_tools.ReadResponse(t, recorder, http.StatusOK, user.expectResponse)
if !expectResponse {
return
}
got := &api.User{}
if err := json.Unmarshal(content, got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, user.userId, got.Id)
assert.NotNil(t, got.IsCurrent)
assert.True(t, *got.IsCurrent)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}
}

View File

@@ -637,6 +637,38 @@ func Test_PATs_Create(t *testing.T) {
}
}
func Test_Users_Update_CrossAccountAttack(t *testing.T) {
t.Run("Admin attempts to update user from other account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
body, _ := json.Marshal(&api.UserRequest{
Role: "user",
AutoGroups: []string{},
IsBlocked: true,
})
// TestAdminId belongs to testAccountId, but targets otherUserId which belongs to otherAccountId
req := testing_tools.BuildRequest(t, body, http.MethodPut, "/api/users/otherUserId", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user update must be rejected")
})
}
func Test_Users_Delete_CrossAccountAttack(t *testing.T) {
t.Run("Admin attempts to delete service user from other account", func(t *testing.T) {
apiHandler, _, _ := channel.BuildApiBlackBoxWithDBState(t, "../testdata/users_integration.sql", nil, false)
// TestAdminId belongs to testAccountId, but targets otherServiceUserId which belongs to otherAccountId
req := testing_tools.BuildRequest(t, []byte{}, http.MethodDelete, "/api/users/otherServiceUserId", testing_tools.TestAdminId)
recorder := httptest.NewRecorder()
apiHandler.ServeHTTP(recorder, req)
assert.NotEqual(t, http.StatusOK, recorder.Code, "cross-account user delete must be rejected")
})
}
func Test_PATs_Delete(t *testing.T) {
users := []struct {
name string

View File

@@ -5,6 +5,7 @@ CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`i
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
@@ -15,4 +16,4 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);

View File

@@ -6,6 +6,7 @@ CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`i
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
@@ -15,7 +16,7 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0);

View File

@@ -0,0 +1,27 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `zones` (`id` text,`account_id` text,`name` text,`domain` text,`enabled` numeric,`enable_search_domain` numeric,`distribution_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_zones` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `records` (`id` text,`account_id` text,`zone_id` text,`name` text,`type` text,`content` text,`ttl` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_zones_records` FOREIGN KEY (`zone_id`) REFERENCES `zones`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO name_server_groups VALUES('testNSGroupId','testAccountId','testNSGroup','test nameserver group','[{"IP":"1.1.1.1","NSType":1,"Port":53}]','["testGroupId"]',0,'["example.com"]',1,0);
INSERT INTO zones VALUES('testZoneId','testAccountId','Test Zone','example.com',1,0,'["testGroupId"]');
INSERT INTO records VALUES('testRecordId','testAccountId','testZoneId','sub.example.com','A','1.2.3.4',300);

View File

@@ -5,6 +5,7 @@ CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` t
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
@@ -14,5 +15,5 @@ INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,N
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -5,6 +5,7 @@ CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` t
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
@@ -15,5 +16,5 @@ INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NU
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('allGroupId','testAccountId','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);

View File

@@ -0,0 +1,21 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `identity_providers` (`id` text,`account_id` text,`type` text,`name` text,`issuer` text,`client_id` text,`client_secret` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_identity_providers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','["testPeerId"]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,0,NULL,'["testGroupId"]',1,0);
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO identity_providers VALUES('testIdpId','testAccountId','oidc','Test IDP','https://issuer.example.com','client123','secret456');

View File

@@ -8,6 +8,7 @@ CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`p
CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`domain` text,`prefix` text,`enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('otherAccountId','','2024-10-02 16:01:38.000000000+00:00','other.com','private',0,'otherNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,'');

Some files were not shown because too many files have changed in this diff Show More