mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 08:54:11 -04:00
Add user invite link feature for embedded IdP (#5157)
This commit is contained in:
@@ -30,6 +30,12 @@ type Manager interface {
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
AcceptUserInvite(ctx context.Context, token, password string) error
|
||||
RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error
|
||||
|
||||
@@ -199,6 +199,11 @@ const (
|
||||
|
||||
UserPasswordChanged Activity = 103
|
||||
|
||||
UserInviteLinkCreated Activity = 104
|
||||
UserInviteLinkAccepted Activity = 105
|
||||
UserInviteLinkRegenerated Activity = 106
|
||||
UserInviteLinkDeleted Activity = 107
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -327,6 +332,11 @@ var activityMap = map[Activity]Code{
|
||||
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
|
||||
|
||||
UserPasswordChanged: {"User password changed", "user.password.change"},
|
||||
|
||||
UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"},
|
||||
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
|
||||
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
|
||||
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -68,6 +68,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
if err := bypass.AddBypassPath("/api/setup"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
// Public invite endpoints (tokens start with nbi_)
|
||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
@@ -132,6 +139,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
|
||||
@@ -145,6 +154,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
|
||||
@@ -28,6 +28,15 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddVersionEndpoint registers the authenticated version endpoint.
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
// This endpoint is unauthenticated.
|
||||
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
Email: userData.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
// This endpoint requires authentication.
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
|
||||
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.InstanceVersionInfo{
|
||||
ManagementCurrentVersion: versionInfo.CurrentVersion,
|
||||
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
|
||||
}
|
||||
|
||||
if versionInfo.DashboardVersion != "" {
|
||||
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
|
||||
}
|
||||
|
||||
if versionInfo.ManagementVersion != "" {
|
||||
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ type mockInstanceManager struct {
|
||||
isSetupRequired bool
|
||||
isSetupRequiredFn func(ctx context.Context) (bool, error)
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
@@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
if m.getVersionInfoFn != nil {
|
||||
return m.getVersionInfoFn(ctx)
|
||||
}
|
||||
return &nbinstance.VersionInfo{
|
||||
CurrentVersion: "0.34.0",
|
||||
DashboardVersion: "2.0.0",
|
||||
ManagementVersion: "0.35.0",
|
||||
ManagementUpdateAvailable: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
|
||||
|
||||
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
|
||||
@@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response api.InstanceVersionInfo
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
|
||||
assert.NotNil(t, response.DashboardAvailableVersion)
|
||||
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
|
||||
assert.NotNil(t, response.ManagementAvailableVersion)
|
||||
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
|
||||
assert.True(t, response.ManagementUpdateAvailable)
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Error(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
return nil, errors.New("failed to fetch versions")
|
||||
},
|
||||
}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
263
management/server/http/handlers/users/invites_handler.go
Normal file
263
management/server/http/handlers/users/invites_handler.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
|
||||
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: 10, // 10 attempts per minute per IP
|
||||
Burst: 5, // Allow burst of 5 requests
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
LimiterTTL: 30 * time.Minute,
|
||||
})
|
||||
|
||||
// toUserInviteResponse converts a UserInvite to an API response.
|
||||
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
|
||||
autoGroups := invite.UserInfo.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
var inviteLink *string
|
||||
if invite.InviteToken != "" {
|
||||
inviteLink = &invite.InviteToken
|
||||
}
|
||||
return api.UserInvite{
|
||||
Id: invite.UserInfo.ID,
|
||||
Email: invite.UserInfo.Email,
|
||||
Name: invite.UserInfo.Name,
|
||||
Role: invite.UserInfo.Role,
|
||||
AutoGroups: autoGroups,
|
||||
ExpiresAt: invite.InviteExpiresAt.UTC(),
|
||||
CreatedAt: invite.InviteCreatedAt.UTC(),
|
||||
Expired: time.Now().After(invite.InviteExpiresAt),
|
||||
InviteToken: inviteLink,
|
||||
}
|
||||
}
|
||||
|
||||
// invitesHandler handles user invite operations
|
||||
type invitesHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Create a subrouter for public invite endpoints with rate limiting middleware
|
||||
publicRouter := router.PathPrefix("/users/invites").Subrouter()
|
||||
publicRouter.Use(publicInviteRateLimiter.Middleware)
|
||||
|
||||
// Public endpoints (no auth required, protected by token and rate limited)
|
||||
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
|
||||
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := make([]api.UserInvite, 0, len(invites))
|
||||
for _, invite := range invites {
|
||||
resp = append(resp, toUserInviteResponse(invite))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
invite := &types.UserInfo{
|
||||
Email: req.Email,
|
||||
Name: req.Name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
}
|
||||
|
||||
expiresIn := 0
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn = *req.ExpiresIn
|
||||
}
|
||||
|
||||
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
result.InviteCreatedAt = time.Now().UTC()
|
||||
resp := toUserInviteResponse(result)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// getInviteInfo handles GET /api/users/invites/{token}
|
||||
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
token := vars["token"]
|
||||
if token == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := info.ExpiresAt.UTC()
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
|
||||
Email: info.Email,
|
||||
Name: info.Name,
|
||||
ExpiresAt: expiresAt,
|
||||
Valid: info.Valid,
|
||||
InvitedBy: info.InvitedBy,
|
||||
})
|
||||
}
|
||||
|
||||
// acceptInvite handles POST /api/users/invites/{token}/accept
|
||||
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
vars := mux.Vars(r)
|
||||
token := vars["token"]
|
||||
if token == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteAcceptRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.UserInviteRegenerateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Allow empty body (io.EOF) - expiresIn is optional
|
||||
if !errors.Is(err, io.EOF) {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
expiresIn := 0
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn = *req.ExpiresIn
|
||||
}
|
||||
|
||||
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := result.InviteExpiresAt.UTC()
|
||||
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
|
||||
InviteToken: result.InviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
642
management/server/http/handlers/users/invites_handler_test.go
Normal file
642
management/server/http/handlers/users/invites_handler_test.go
Normal file
@@ -0,0 +1,642 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
testInviteID = "test-invite-id"
|
||||
testInviteToken = "nbi_testtoken123456789012345678"
|
||||
testEmail = "invite@example.com"
|
||||
testName = "Test User"
|
||||
)
|
||||
|
||||
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
|
||||
return &invitesHandler{
|
||||
accountManager: am,
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInvites(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
testInvites := []*types.UserInvite{
|
||||
{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "invite-1",
|
||||
Email: "user1@example.com",
|
||||
Name: "User One",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
},
|
||||
InviteExpiresAt: now.Add(24 * time.Hour),
|
||||
InviteCreatedAt: now,
|
||||
},
|
||||
{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: "invite-2",
|
||||
Email: "user2@example.com",
|
||||
Name: "User Two",
|
||||
Role: "admin",
|
||||
AutoGroups: nil,
|
||||
},
|
||||
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
|
||||
InviteCreatedAt: now.Add(-48 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "successful list",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return testInvites, nil
|
||||
},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return []*types.UserInvite{}, nil
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
ListUserInvitesFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.listInvites(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp []api.UserInvite
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, tc.expectedCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInvite(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
}{
|
||||
{
|
||||
name: "successful create",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: testInviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
},
|
||||
InviteToken: testInviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful create with custom expiration",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 3600, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: testInviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: []string{},
|
||||
Status: string(types.UserStatusInvited),
|
||||
},
|
||||
InviteToken: testInviteToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user already exists",
|
||||
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusConflict,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite already exists",
|
||||
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusConflict,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
requestBody: `{invalid json}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
CreateUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.createInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInvite
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testInviteID, resp.Id)
|
||||
assert.NotNil(t, resp.InviteToken)
|
||||
assert.NotEmpty(t, *resp.InviteToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInviteInfo(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
token string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
}{
|
||||
{
|
||||
name: "successful get valid invite",
|
||||
token: testInviteToken,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return &types.UserInviteInfo{
|
||||
Email: testEmail,
|
||||
Name: testName,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
Valid: true,
|
||||
InvitedBy: "Admin User",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful get expired invite",
|
||||
token: testInviteToken,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return &types.UserInviteInfo{
|
||||
Email: testEmail,
|
||||
Name: testName,
|
||||
ExpiresAt: now.Add(-24 * time.Hour),
|
||||
Valid: false,
|
||||
InvitedBy: "Admin User",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
token: "nbi_invalidtoken1234567890123456",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return nil, status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid token format",
|
||||
token: "invalid",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
GetUserInviteInfoFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
|
||||
if tc.token != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.getInviteInfo(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteInfo
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testEmail, resp.Email)
|
||||
assert.Equal(t, testName, resp.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcceptInvite(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
token string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, token, password string) error
|
||||
}{
|
||||
{
|
||||
name: "successful accept",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
token: "nbi_invalidtoken1234567890123456",
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite expired",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "invite has expired")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing token",
|
||||
token: "",
|
||||
requestBody: `{"password":"SecurePass123!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
token: testInviteToken,
|
||||
requestBody: `{invalid}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"Short1!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing digit",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"NoDigitPass!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing uppercase",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"nouppercase1!"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "password missing special character",
|
||||
token: testInviteToken,
|
||||
requestBody: `{"password":"NoSpecial123"}`,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: func(ctx context.Context, token, password string) error {
|
||||
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
AcceptUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
|
||||
if tc.token != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.acceptInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteAcceptResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Success)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegenerateInvite(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
inviteID string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
}{
|
||||
{
|
||||
name: "successful regenerate with empty body",
|
||||
inviteID: testInviteID,
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 0, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: testEmail,
|
||||
},
|
||||
InviteToken: "nbi_newtoken12345678901234567890",
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful regenerate with custom expiration",
|
||||
inviteID: testInviteID,
|
||||
requestBody: `{"expires_in":7200}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
assert.Equal(t, 7200, expiresIn)
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: testEmail,
|
||||
},
|
||||
InviteToken: "nbi_newtoken12345678901234567890",
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
inviteID: "non-existent-invite",
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
inviteID: testInviteID,
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing invite ID",
|
||||
inviteID: "",
|
||||
requestBody: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON should return error",
|
||||
inviteID: testInviteID,
|
||||
requestBody: `{invalid json}`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
RegenerateUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
var body io.Reader
|
||||
if tc.requestBody != "" {
|
||||
body = bytes.NewBufferString(tc.requestBody)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
if tc.inviteID != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.regenerateInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
var resp api.UserInviteRegenerateResponse
|
||||
err := json.NewDecoder(rr.Body).Decode(&resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.InviteToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInvite(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
inviteID string
|
||||
expectedStatus int
|
||||
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}{
|
||||
{
|
||||
name: "successful delete",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusOK,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invite not found",
|
||||
inviteID: "non-existent-invite",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.Errorf(status.NotFound, "invite not found")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission denied",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.NewPermissionDeniedError()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embedded IDP not enabled",
|
||||
inviteID: testInviteID,
|
||||
expectedStatus: http.StatusPreconditionFailed,
|
||||
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing invite ID",
|
||||
inviteID: "",
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
mockFunc: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{
|
||||
DeleteUserInviteFunc: tc.mockFunc,
|
||||
}
|
||||
handler := setupInvitesTestHandler(am)
|
||||
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
if tc.inviteID != "" {
|
||||
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.deleteInvite(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that rate limits requests by client IP.
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
158
management/server/http/middleware/rate_limiter_test.go
Normal file
158
management/server/http/middleware/rate_limiter_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAPIRateLimiter_Allow(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// First two requests should be allowed (burst)
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
|
||||
// Third request should be denied (exceeded burst)
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Different key should be allowed
|
||||
assert.True(t, rl.Allow("different-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Create a simple handler that returns 200 OK
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Wrap with rate limiter middleware
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// First two requests should pass (burst)
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// Request from first IP
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr1, req1)
|
||||
assert.Equal(t, http.StatusOK, rr1.Code)
|
||||
|
||||
// Second request from first IP should be rate limited
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.1:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
|
||||
|
||||
// Request from different IP should be allowed
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req3.RemoteAddr = "192.168.1.2:12345"
|
||||
rr3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr3, req3)
|
||||
assert.Equal(t, http.StatusOK, rr3.Code)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "remote addr with port",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "remote addr without port",
|
||||
remoteAddr: "192.168.1.1",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 with port",
|
||||
remoteAddr: "[::1]:12345",
|
||||
expected: "::1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 without port",
|
||||
remoteAddr: "::1",
|
||||
expected: "::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
assert.Equal(t, tc.expected, getClientIP(req))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Use up the burst
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Reset the limiter
|
||||
rl.Reset("test-key")
|
||||
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
@@ -2,18 +2,54 @@ package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// Version endpoints
|
||||
managementVersionURL = "https://pkgs.netbird.io/releases/latest/version"
|
||||
dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest"
|
||||
|
||||
// Cache TTL for version information
|
||||
versionCacheTTL = 60 * time.Minute
|
||||
|
||||
// HTTP client timeout
|
||||
httpTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// VersionInfo contains version information for NetBird components
|
||||
type VersionInfo struct {
|
||||
// CurrentVersion is the running management server version
|
||||
CurrentVersion string
|
||||
// DashboardVersion is the latest available dashboard version from GitHub
|
||||
DashboardVersion string
|
||||
// ManagementVersion is the latest available management version from GitHub
|
||||
ManagementVersion string
|
||||
// ManagementUpdateAvailable indicates if a newer management version is available
|
||||
ManagementUpdateAvailable bool
|
||||
}
|
||||
|
||||
// githubRelease represents a GitHub release response
|
||||
type githubRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
}
|
||||
|
||||
// Manager handles instance-level operations like initial setup.
|
||||
type Manager interface {
|
||||
// IsSetupRequired checks if instance setup is required.
|
||||
@@ -23,6 +59,9 @@ type Manager interface {
|
||||
// CreateOwnerUser creates the initial owner user in the embedded IDP.
|
||||
// This should only be called when IsSetupRequired returns true.
|
||||
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
|
||||
}
|
||||
|
||||
// DefaultManager is the default implementation of Manager.
|
||||
@@ -32,6 +71,12 @@ type DefaultManager struct {
|
||||
|
||||
setupRequired bool
|
||||
setupMu sync.RWMutex
|
||||
|
||||
// Version caching
|
||||
httpClient *http.Client
|
||||
versionMu sync.RWMutex
|
||||
cachedVersions *VersionInfo
|
||||
lastVersionFetch time.Time
|
||||
}
|
||||
|
||||
// NewManager creates a new instance manager.
|
||||
@@ -43,6 +88,9 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
|
||||
store: store,
|
||||
embeddedIdpManager: embeddedIdp,
|
||||
setupRequired: false,
|
||||
httpClient: &http.Client{
|
||||
Timeout: httpTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
if embeddedIdp != nil {
|
||||
@@ -134,3 +182,130 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) {
|
||||
m.versionMu.RLock()
|
||||
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
|
||||
cached := *m.cachedVersions
|
||||
m.versionMu.RUnlock()
|
||||
return &cached, nil
|
||||
}
|
||||
m.versionMu.RUnlock()
|
||||
|
||||
return m.fetchVersionInfo(ctx)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
|
||||
m.versionMu.Lock()
|
||||
// Double-check after acquiring write lock
|
||||
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
|
||||
cached := *m.cachedVersions
|
||||
m.versionMu.Unlock()
|
||||
return &cached, nil
|
||||
}
|
||||
m.versionMu.Unlock()
|
||||
|
||||
info := &VersionInfo{
|
||||
CurrentVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
// Fetch management version from pkgs.netbird.io (plain text)
|
||||
mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to fetch management version: %v", err)
|
||||
} else {
|
||||
info.ManagementVersion = mgmtVersion
|
||||
info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion)
|
||||
}
|
||||
|
||||
// Fetch dashboard version from GitHub
|
||||
dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err)
|
||||
} else {
|
||||
info.DashboardVersion = dashVersion
|
||||
}
|
||||
|
||||
// Update cache
|
||||
m.versionMu.Lock()
|
||||
m.cachedVersions = info
|
||||
m.lastVersionFetch = time.Now()
|
||||
m.versionMu.Unlock()
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// isNewerVersion returns true if latestVersion is greater than currentVersion
|
||||
func isNewerVersion(currentVersion, latestVersion string) bool {
|
||||
current, err := goversion.NewVersion(currentVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
latest, err := goversion.NewVersion(latestVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return latest.GreaterThan(current)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 100))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
|
||||
|
||||
resp, err := m.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release githubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
// Remove 'v' prefix if present
|
||||
tag := release.TagName
|
||||
if len(tag) > 0 && tag[0] == 'v' {
|
||||
tag = tag[1:]
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
285
management/server/instance/version_test.go
Normal file
285
management/server/instance/version_test.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockRoundTripper implements http.RoundTripper for testing
|
||||
type mockRoundTripper struct {
|
||||
callCount atomic.Int32
|
||||
managementVersion string
|
||||
dashboardVersion string
|
||||
}
|
||||
|
||||
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
m.callCount.Add(1)
|
||||
|
||||
var body string
|
||||
if strings.Contains(req.URL.String(), "pkgs.netbird.io") {
|
||||
// Plain text response for management version
|
||||
body = m.managementVersion
|
||||
} else if strings.Contains(req.URL.String(), "github.com") {
|
||||
// JSON response for dashboard version
|
||||
jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion})
|
||||
body = string(jsonResp)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewBufferString(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) {
|
||||
mockTransport := &mockRoundTripper{
|
||||
managementVersion: "0.65.0",
|
||||
dashboardVersion: "2.10.0",
|
||||
}
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Transport: mockTransport},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
info, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// CurrentVersion should always be set
|
||||
assert.NotEmpty(t, info.CurrentVersion)
|
||||
assert.Equal(t, "0.65.0", info.ManagementVersion)
|
||||
assert.Equal(t, "2.10.0", info.DashboardVersion)
|
||||
assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard
|
||||
}
|
||||
|
||||
func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) {
|
||||
mockTransport := &mockRoundTripper{
|
||||
managementVersion: "0.65.0",
|
||||
dashboardVersion: "2.10.0",
|
||||
}
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Transport: mockTransport},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First call
|
||||
info1, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, info1.CurrentVersion)
|
||||
assert.Equal(t, "0.65.0", info1.ManagementVersion)
|
||||
|
||||
initialCallCount := mockTransport.callCount.Load()
|
||||
|
||||
// Second call should use cache (no additional HTTP calls)
|
||||
info2, err := m.GetVersionInfo(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion)
|
||||
assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion)
|
||||
assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion)
|
||||
|
||||
// Verify no additional HTTP calls were made (cache was used)
|
||||
assert.Equal(t, initialCallCount, mockTransport.callCount.Load())
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tagName string
|
||||
expected string
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "tag with v prefix",
|
||||
tagName: "v1.2.3",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "tag without v prefix",
|
||||
tagName: "1.2.3",
|
||||
expected: "1.2.3",
|
||||
},
|
||||
{
|
||||
name: "tag with prerelease",
|
||||
tagName: "v2.0.0-beta.1",
|
||||
expected: "2.0.0-beta.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
version, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
|
||||
if tc.shouldError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, version)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
statusCode: http.StatusNotFound,
|
||||
body: `{"message": "Not Found"}`,
|
||||
},
|
||||
{
|
||||
name: "rate limited",
|
||||
statusCode: http.StatusForbidden,
|
||||
body: `{"message": "API rate limit exceeded"}`,
|
||||
},
|
||||
{
|
||||
name: "server error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
body: `{"message": "Internal Server Error"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tc.statusCode)
|
||||
_, _ = w.Write([]byte(tc.body))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{invalid json}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1 * time.Second)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
m := &DefaultManager{
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := m.fetchGitHubRelease(ctx, server.URL)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIsNewerVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentVersion string
|
||||
latestVersion string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "latest is newer - minor version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.65.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "latest is newer - patch version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.64.2",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "latest is newer - major version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "1.0.0",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "versions are equal",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "current is newer - minor version",
|
||||
currentVersion: "0.65.0",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "current is newer - patch version",
|
||||
currentVersion: "0.64.2",
|
||||
latestVersion: "0.64.1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "development version",
|
||||
currentVersion: "development",
|
||||
latestVersion: "0.65.0",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid latest version",
|
||||
currentVersion: "0.64.1",
|
||||
latestVersion: "invalid",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := isNewerVersion(tc.currentVersion, tc.latestVersion)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -139,6 +139,12 @@ type MockAccountManager struct {
|
||||
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
|
||||
AcceptUserInviteFunc func(ctx context.Context, token, password string) error
|
||||
RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
|
||||
GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
|
||||
ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
|
||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
|
||||
@@ -713,6 +719,48 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
if am.CreateUserInviteFunc != nil {
|
||||
return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
|
||||
if am.AcceptUserInviteFunc != nil {
|
||||
return am.AcceptUserInviteFunc(ctx, token, password)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
if am.RegenerateUserInviteFunc != nil {
|
||||
return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
if am.GetUserInviteInfoFunc != nil {
|
||||
return am.GetUserInviteInfoFunc(ctx, token)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
if am.ListUserInvitesFunc != nil {
|
||||
return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
if am.DeleteUserInviteFunc != nil {
|
||||
return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
|
||||
if am.GetAccountIDFromUserAuthFunc != nil {
|
||||
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)
|
||||
|
||||
@@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -815,6 +815,130 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// SaveUserInvite saves a user invite to the database
|
||||
func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error {
|
||||
inviteCopy := invite.Copy()
|
||||
if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt invite: %w", err)
|
||||
}
|
||||
|
||||
result := s.db.Save(inviteCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save user invite to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserInviteByID retrieves a user invite by its ID and account ID
|
||||
func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invite types.UserInviteRecord
|
||||
result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user invite not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
|
||||
}
|
||||
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
// GetUserInviteByHashedToken retrieves a user invite by its hashed token
|
||||
func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invite types.UserInviteRecord
|
||||
result := tx.Take(&invite, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user invite not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
|
||||
}
|
||||
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
// GetUserInviteByEmail retrieves a user invite by account ID and email.
|
||||
// Since email is encrypted with random IVs, we fetch all invites for the account
|
||||
// and compare emails in memory after decryption.
|
||||
func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invites []*types.UserInviteRecord
|
||||
result := tx.Find(&invites, "account_id = ?", accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
|
||||
}
|
||||
|
||||
for _, invite := range invites {
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
if strings.EqualFold(invite.Email, email) {
|
||||
return invite, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "user invite not found for email")
|
||||
}
|
||||
|
||||
// GetAccountUserInvites retrieves all user invites for an account
|
||||
func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var invites []*types.UserInviteRecord
|
||||
result := tx.Find(&invites, "account_id = ?", accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
|
||||
}
|
||||
|
||||
for _, invite := range invites {
|
||||
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt invite: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return invites, nil
|
||||
}
|
||||
|
||||
// DeleteUserInvite deletes a user invite by its ID
|
||||
func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error {
|
||||
result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete user invite from store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
520
management/server/store/sql_store_user_invite_test.go
Normal file
520
management/server/store/sql_store_user_invite_test.go
Normal file
@@ -0,0 +1,520 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestSqlStore_SaveUserInvite(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-1",
|
||||
AccountID: "account-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1", "group-2"},
|
||||
HashedToken: "hashed-token-123",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the invite was saved
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
assert.Equal(t, invite.Email, retrieved.Email)
|
||||
assert.Equal(t, invite.Name, retrieved.Name)
|
||||
assert.Equal(t, invite.Role, retrieved.Role)
|
||||
assert.Equal(t, invite.AutoGroups, retrieved.AutoGroups)
|
||||
assert.Equal(t, invite.CreatedBy, retrieved.CreatedBy)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUserInvite_Update(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-update",
|
||||
AccountID: "account-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-123",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update the invite with a new token
|
||||
invite.HashedToken = "new-hashed-token"
|
||||
invite.ExpiresAt = time.Now().Add(24 * time.Hour)
|
||||
|
||||
err = store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the update
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new-hashed-token", retrieved.HashedToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserInviteByID(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-get-by-id",
|
||||
AccountID: "account-1",
|
||||
Email: "getbyid@example.com",
|
||||
Name: "Get By ID User",
|
||||
Role: "admin",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-get",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by ID - success
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
assert.Equal(t, invite.Email, retrieved.Email)
|
||||
|
||||
// Get by ID - wrong account
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Get by ID - not found
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-get-by-token",
|
||||
AccountID: "account-1",
|
||||
Email: "getbytoken@example.com",
|
||||
Name: "Get By Token User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "unique-hashed-token-456",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by hashed token - success
|
||||
retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
assert.Equal(t, invite.Email, retrieved.Email)
|
||||
|
||||
// Get by hashed token - not found
|
||||
_, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserInviteByEmail(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-get-by-email",
|
||||
AccountID: "account-email-test",
|
||||
Email: "unique-email@example.com",
|
||||
Name: "Get By Email User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-email",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get by email - success
|
||||
retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
// Get by email - case insensitive
|
||||
retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
// Get by email - wrong account
|
||||
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Get by email - not found
|
||||
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountUserInvites(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
accountID := "account-list-invites"
|
||||
|
||||
invites := []*types.UserInviteRecord{
|
||||
{
|
||||
ID: "invite-list-1",
|
||||
AccountID: accountID,
|
||||
Email: "user1@example.com",
|
||||
Name: "User One",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-list-1",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
},
|
||||
{
|
||||
ID: "invite-list-2",
|
||||
AccountID: accountID,
|
||||
Email: "user2@example.com",
|
||||
Name: "User Two",
|
||||
Role: "admin",
|
||||
AutoGroups: []string{"group-2"},
|
||||
HashedToken: "hashed-token-list-2",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
},
|
||||
{
|
||||
ID: "invite-list-3",
|
||||
AccountID: "different-account",
|
||||
Email: "user3@example.com",
|
||||
Name: "User Three",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-list-3",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
},
|
||||
}
|
||||
|
||||
for _, invite := range invites {
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Get all invites for the account
|
||||
retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, retrieved, 2)
|
||||
|
||||
// Verify the invites belong to the correct account
|
||||
for _, invite := range retrieved {
|
||||
assert.Equal(t, accountID, invite.AccountID)
|
||||
}
|
||||
|
||||
// Get invites for account with no invites
|
||||
retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, retrieved, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUserInvite(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-delete",
|
||||
AccountID: "account-delete-test",
|
||||
Email: "delete@example.com",
|
||||
Name: "Delete User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-delete",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify invite exists
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the invite
|
||||
err = store.DeleteUserInvite(ctx, invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify invite is deleted
|
||||
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-encrypted",
|
||||
AccountID: "account-encrypted",
|
||||
Email: "sensitive-email@example.com",
|
||||
Name: "Sensitive Name",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-encrypted",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve and verify decryption works
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "sensitive-email@example.com", retrieved.Email)
|
||||
assert.Equal(t, "Sensitive Name", retrieved.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Deleting a non-existent invite should not return an error
|
||||
err := store.DeleteUserInvite(ctx, "non-existent-invite-id")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
email := "shared-email@example.com"
|
||||
|
||||
// Create invite in first account
|
||||
invite1 := &types.UserInviteRecord{
|
||||
ID: "invite-account1",
|
||||
AccountID: "account-1",
|
||||
Email: email,
|
||||
Name: "User Account 1",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-account1",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-1",
|
||||
}
|
||||
|
||||
// Create invite in second account with same email
|
||||
invite2 := &types.UserInviteRecord{
|
||||
ID: "invite-account2",
|
||||
AccountID: "account-2",
|
||||
Email: email,
|
||||
Name: "User Account 2",
|
||||
Role: "admin",
|
||||
AutoGroups: []string{"group-1"},
|
||||
HashedToken: "hashed-token-account2",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-2",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveUserInvite(ctx, invite2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify each account gets the correct invite by email
|
||||
retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "invite-account1", retrieved1.ID)
|
||||
assert.Equal(t, "User Account 1", retrieved1.Name)
|
||||
|
||||
retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "invite-account2", retrieved2.ID)
|
||||
assert.Equal(t, "User Account 2", retrieved2.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_LockingStrength(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-locking",
|
||||
AccountID: "account-locking",
|
||||
Email: "locking@example.com",
|
||||
Name: "Locking Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-locking",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test with different locking strengths
|
||||
lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate}
|
||||
|
||||
for _, strength := range lockStrengths {
|
||||
retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, invite.ID, retrieved.ID)
|
||||
|
||||
invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, invites, 1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
// Test with nil AutoGroups
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-nil-autogroups",
|
||||
AccountID: "account-autogroups",
|
||||
Email: "nilgroups@example.com",
|
||||
Name: "Nil Groups User",
|
||||
Role: "user",
|
||||
AutoGroups: nil,
|
||||
HashedToken: "hashed-token-nil",
|
||||
ExpiresAt: time.Now().Add(72 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
// Should return empty slice or nil, both are acceptable
|
||||
assert.Empty(t, retrieved.AutoGroups)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Skip("store is nil")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC().Truncate(time.Millisecond)
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
invite := &types.UserInviteRecord{
|
||||
ID: "invite-timestamp",
|
||||
AccountID: "account-timestamp",
|
||||
Email: "timestamp@example.com",
|
||||
Name: "Timestamp User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
HashedToken: "hashed-token-timestamp",
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: now,
|
||||
CreatedBy: "admin-user",
|
||||
}
|
||||
|
||||
err := store.SaveUserInvite(ctx, invite)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify timestamps are preserved (within reasonable precision)
|
||||
assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second)
|
||||
assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second)
|
||||
})
|
||||
}
|
||||
@@ -92,6 +92,13 @@ type Store interface {
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error
|
||||
GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error)
|
||||
GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error)
|
||||
GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error)
|
||||
GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error)
|
||||
DeleteUserInvite(ctx context.Context, inviteID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
||||
|
||||
201
management/server/types/user_invite.go
Normal file
201
management/server/types/user_invite.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// InviteTokenPrefix is the prefix for invite tokens
|
||||
InviteTokenPrefix = "nbi_"
|
||||
// InviteTokenSecretLength is the length of the random secret part
|
||||
InviteTokenSecretLength = 30
|
||||
// InviteTokenChecksumLength is the length of the encoded checksum
|
||||
InviteTokenChecksumLength = 6
|
||||
// InviteTokenLength is the total length of the token (4 + 30 + 6 = 40)
|
||||
InviteTokenLength = 40
|
||||
// DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours)
|
||||
DefaultInviteExpirationSeconds = 259200
|
||||
// MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour)
|
||||
MinInviteExpirationSeconds = 3600
|
||||
)
|
||||
|
||||
// UserInviteRecord represents an invitation for a user to set up their account (database model)
|
||||
type UserInviteRecord struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index;not null"`
|
||||
Email string `gorm:"index;not null"`
|
||||
Name string `gorm:"not null"`
|
||||
Role string `gorm:"not null"`
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded)
|
||||
ExpiresAt time.Time `gorm:"not null"`
|
||||
CreatedAt time.Time `gorm:"not null"`
|
||||
CreatedBy string `gorm:"not null"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (UserInviteRecord) TableName() string {
|
||||
return "user_invites"
|
||||
}
|
||||
|
||||
// GenerateInviteToken creates a new invite token with the format: nbi_<secret><checksum>
|
||||
// Returns the hashed token (for storage) and the plain token (to give to the user)
|
||||
func GenerateInviteToken() (hashedToken string, plainToken string, err error) {
|
||||
secret, err := b.Random(InviteTokenSecretLength)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random secret: %w", err)
|
||||
}
|
||||
|
||||
checksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
encodedChecksum := base62.Encode(checksum)
|
||||
// Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode)
|
||||
paddedChecksum := encodedChecksum
|
||||
if len(paddedChecksum) < InviteTokenChecksumLength {
|
||||
paddedChecksum = strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum
|
||||
}
|
||||
|
||||
plainToken = InviteTokenPrefix + secret + paddedChecksum
|
||||
hash := sha256.Sum256([]byte(plainToken))
|
||||
hashedToken = b64.StdEncoding.EncodeToString(hash[:])
|
||||
|
||||
return hashedToken, plainToken, nil
|
||||
}
|
||||
|
||||
// HashInviteToken creates a SHA-256 hash of the token (base64 encoded)
|
||||
func HashInviteToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return b64.StdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// ValidateInviteToken validates the token format and checksum.
|
||||
// Returns an error if the token is invalid.
|
||||
func ValidateInviteToken(token string) error {
|
||||
if len(token) != InviteTokenLength {
|
||||
return fmt.Errorf("invalid token length")
|
||||
}
|
||||
|
||||
prefix := token[:len(InviteTokenPrefix)]
|
||||
if prefix != InviteTokenPrefix {
|
||||
return fmt.Errorf("invalid token prefix")
|
||||
}
|
||||
|
||||
secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
|
||||
encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return fmt.Errorf("checksum does not match")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the invite has expired
|
||||
func (i *UserInviteRecord) IsExpired() bool {
|
||||
return time.Now().After(i.ExpiresAt)
|
||||
}
|
||||
|
||||
// UserInvite contains the result of creating or regenerating an invite
|
||||
type UserInvite struct {
|
||||
UserInfo *UserInfo
|
||||
InviteToken string
|
||||
InviteExpiresAt time.Time
|
||||
InviteCreatedAt time.Time
|
||||
}
|
||||
|
||||
// UserInviteInfo contains public information about an invite (for unauthenticated endpoint)
|
||||
type UserInviteInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Valid bool `json:"valid"`
|
||||
InvitedBy string `json:"invited_by"`
|
||||
}
|
||||
|
||||
// NewInviteID generates a new invite ID using xid
|
||||
func NewInviteID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place.
|
||||
func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if i.Email != "" {
|
||||
i.Email, err = enc.Encrypt(i.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if i.Name != "" {
|
||||
i.Name, err = enc.Encrypt(i.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place.
|
||||
func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if i.Email != "" {
|
||||
i.Email, err = enc.Decrypt(i.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if i.Name != "" {
|
||||
i.Name, err = enc.Decrypt(i.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of the UserInviteRecord
|
||||
func (i *UserInviteRecord) Copy() *UserInviteRecord {
|
||||
autoGroups := make([]string, len(i.AutoGroups))
|
||||
copy(autoGroups, i.AutoGroups)
|
||||
|
||||
return &UserInviteRecord{
|
||||
ID: i.ID,
|
||||
AccountID: i.AccountID,
|
||||
Email: i.Email,
|
||||
Name: i.Name,
|
||||
Role: i.Role,
|
||||
AutoGroups: autoGroups,
|
||||
HashedToken: i.HashedToken,
|
||||
ExpiresAt: i.ExpiresAt,
|
||||
CreatedAt: i.CreatedAt,
|
||||
CreatedBy: i.CreatedBy,
|
||||
}
|
||||
}
|
||||
355
management/server/types/user_invite_test.go
Normal file
355
management/server/types/user_invite_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"hash/crc32"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUserInviteRecord_TableName(t *testing.T) {
|
||||
invite := UserInviteRecord{}
|
||||
assert.Equal(t, "user_invites", invite.TableName())
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Success(t *testing.T) {
|
||||
hashedToken, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, hashedToken)
|
||||
assert.NotEmpty(t, plainToken)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Length(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, plainToken, InviteTokenLength)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Prefix(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix))
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Hashing(t *testing.T) {
|
||||
hashedToken, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedHash := sha256.Sum256([]byte(plainToken))
|
||||
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
assert.Equal(t, expectedHashedToken, hashedToken)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Checksum(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extract parts
|
||||
secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
|
||||
checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:]
|
||||
|
||||
// Verify checksum
|
||||
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
actualChecksum, err := base62.Decode(checksumStr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedChecksum, actualChecksum)
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_Uniqueness(t *testing.T) {
|
||||
tokens := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, tokens[plainToken], "Token should be unique")
|
||||
tokens[plainToken] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashInviteToken(t *testing.T) {
|
||||
token := "nbi_testtoken123456789012345678901234"
|
||||
hashedToken := HashInviteToken(token)
|
||||
|
||||
expectedHash := sha256.Sum256([]byte(token))
|
||||
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
|
||||
assert.Equal(t, expectedHashedToken, hashedToken)
|
||||
}
|
||||
|
||||
func TestHashInviteToken_Consistency(t *testing.T) {
|
||||
token := "nbi_testtoken123456789012345678901234"
|
||||
hash1 := HashInviteToken(token)
|
||||
hash2 := HashInviteToken(token)
|
||||
assert.Equal(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestHashInviteToken_DifferentTokens(t *testing.T) {
|
||||
token1 := "nbi_testtoken123456789012345678901234"
|
||||
token2 := "nbi_testtoken123456789012345678901235"
|
||||
hash1 := HashInviteToken(token1)
|
||||
hash2 := HashInviteToken(token2)
|
||||
assert.NotEqual(t, hash1, hash2)
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_Success(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ValidateInviteToken(plainToken)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_InvalidLength(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"too short", "nbi_abc"},
|
||||
{"too long", "nbi_" + strings.Repeat("a", 50)},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateInviteToken(tc.token)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid token length")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_InvalidPrefix(t *testing.T) {
|
||||
// Create a token with wrong prefix but correct length
|
||||
token := "xyz_" + strings.Repeat("a", 30) + "000000"
|
||||
err := ValidateInviteToken(token)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid token prefix")
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_InvalidChecksum(t *testing.T) {
|
||||
// Create a token with correct format but invalid checksum
|
||||
token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ"
|
||||
err := ValidateInviteToken(token)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "checksum")
|
||||
}
|
||||
|
||||
func TestValidateInviteToken_ModifiedToken(t *testing.T) {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify one character in the secret part
|
||||
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
|
||||
err = ValidateInviteToken(modifiedToken)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_IsExpired(t *testing.T) {
|
||||
t.Run("not expired", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
assert.False(t, invite.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("expired", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ExpiresAt: time.Now().Add(-time.Hour),
|
||||
}
|
||||
assert.True(t, invite.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("just expired", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ExpiresAt: time.Now().Add(-time.Second),
|
||||
}
|
||||
assert.True(t, invite.IsExpired())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewInviteID(t *testing.T) {
|
||||
id := NewInviteID()
|
||||
assert.NotEmpty(t, id)
|
||||
assert.Len(t, id, 20) // xid generates 20 character IDs
|
||||
}
|
||||
|
||||
func TestNewInviteID_Uniqueness(t *testing.T) {
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := NewInviteID()
|
||||
assert.False(t, ids[id], "ID should be unique")
|
||||
ids[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("encrypt and decrypt", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ID: "test-invite",
|
||||
AccountID: "test-account",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
err := invite.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify encrypted values are different from original
|
||||
assert.NotEqual(t, "test@example.com", invite.Email)
|
||||
assert.NotEqual(t, "Test User", invite.Name)
|
||||
|
||||
// Decrypt
|
||||
err = invite.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify decrypted values match original
|
||||
assert.Equal(t, "test@example.com", invite.Email)
|
||||
assert.Equal(t, "Test User", invite.Name)
|
||||
})
|
||||
|
||||
t.Run("encrypt empty fields", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ID: "test-invite",
|
||||
AccountID: "test-account",
|
||||
Email: "",
|
||||
Name: "",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := invite.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", invite.Email)
|
||||
assert.Equal(t, "", invite.Name)
|
||||
|
||||
err = invite.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "", invite.Email)
|
||||
assert.Equal(t, "", invite.Name)
|
||||
})
|
||||
|
||||
t.Run("nil encryptor", func(t *testing.T) {
|
||||
invite := &UserInviteRecord{
|
||||
ID: "test-invite",
|
||||
AccountID: "test-account",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
}
|
||||
|
||||
err := invite.EncryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test@example.com", invite.Email)
|
||||
assert.Equal(t, "Test User", invite.Name)
|
||||
|
||||
err = invite.DecryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test@example.com", invite.Email)
|
||||
assert.Equal(t, "Test User", invite.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_Copy(t *testing.T) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(72 * time.Hour)
|
||||
|
||||
original := &UserInviteRecord{
|
||||
ID: "invite-id",
|
||||
AccountID: "account-id",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
HashedToken: "hashed-token",
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: now,
|
||||
CreatedBy: "creator-id",
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
// Verify all fields are copied
|
||||
assert.Equal(t, original.ID, copied.ID)
|
||||
assert.Equal(t, original.AccountID, copied.AccountID)
|
||||
assert.Equal(t, original.Email, copied.Email)
|
||||
assert.Equal(t, original.Name, copied.Name)
|
||||
assert.Equal(t, original.Role, copied.Role)
|
||||
assert.Equal(t, original.AutoGroups, copied.AutoGroups)
|
||||
assert.Equal(t, original.HashedToken, copied.HashedToken)
|
||||
assert.Equal(t, original.ExpiresAt, copied.ExpiresAt)
|
||||
assert.Equal(t, original.CreatedAt, copied.CreatedAt)
|
||||
assert.Equal(t, original.CreatedBy, copied.CreatedBy)
|
||||
|
||||
// Verify deep copy of AutoGroups (modifying copy doesn't affect original)
|
||||
copied.AutoGroups[0] = "modified"
|
||||
assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0])
|
||||
assert.Equal(t, "group1", original.AutoGroups[0])
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) {
|
||||
original := &UserInviteRecord{
|
||||
ID: "invite-id",
|
||||
AccountID: "account-id",
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
assert.NotNil(t, copied.AutoGroups)
|
||||
assert.Len(t, copied.AutoGroups, 0)
|
||||
}
|
||||
|
||||
func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) {
|
||||
original := &UserInviteRecord{
|
||||
ID: "invite-id",
|
||||
AccountID: "account-id",
|
||||
AutoGroups: nil,
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
assert.NotNil(t, copied.AutoGroups)
|
||||
assert.Len(t, copied.AutoGroups, 0)
|
||||
}
|
||||
|
||||
func TestInviteTokenConstants(t *testing.T) {
|
||||
// Verify constants are consistent
|
||||
expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength
|
||||
assert.Equal(t, InviteTokenLength, expectedLength)
|
||||
assert.Equal(t, 4, len(InviteTokenPrefix))
|
||||
assert.Equal(t, 30, InviteTokenSecretLength)
|
||||
assert.Equal(t, 6, InviteTokenChecksumLength)
|
||||
assert.Equal(t, 40, InviteTokenLength)
|
||||
assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours
|
||||
assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour
|
||||
}
|
||||
|
||||
func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) {
|
||||
// Generate multiple tokens and ensure they all validate
|
||||
for i := 0; i < 50; i++ {
|
||||
_, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ValidateInviteToken(plainToken)
|
||||
assert.NoError(t, err, "Generated token should always be valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) {
|
||||
hashedToken, plainToken, err := GenerateInviteToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
// HashInviteToken should produce the same hash as GenerateInviteToken
|
||||
rehashedToken := HashInviteToken(plainToken)
|
||||
assert.Equal(t, hashedToken, rehashedToken)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
@@ -1453,3 +1454,368 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUserInvite creates an invite link for a new user in the embedded IdP.
|
||||
// The user is NOT created until the invite is accepted.
|
||||
func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
if err := validateUserInvite(invite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Check if user already exists in NetBird DB
|
||||
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, user := range existingUsers {
|
||||
if strings.EqualFold(user.Email, invite.Email) {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if invite already exists for this email
|
||||
existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return nil, fmt.Errorf("failed to check existing invites: %w", err)
|
||||
}
|
||||
}
|
||||
if existingInvite != nil {
|
||||
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
|
||||
}
|
||||
|
||||
// Calculate expiration time
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = types.DefaultInviteExpirationSeconds
|
||||
}
|
||||
|
||||
if expiresIn < types.MinInviteExpirationSeconds {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
|
||||
}
|
||||
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Generate invite token
|
||||
inviteID := types.NewInviteID()
|
||||
hashedToken, plainToken, err := types.GenerateInviteToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate invite token: %w", err)
|
||||
}
|
||||
|
||||
// Create the invite record (no user created yet)
|
||||
userInvite := &types.UserInviteRecord{
|
||||
ID: inviteID,
|
||||
AccountID: accountID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
HashedToken: hashedToken,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
CreatedBy: initiatorUserID,
|
||||
}
|
||||
|
||||
if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email})
|
||||
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: inviteID,
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
Role: invite.Role,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
InviteToken: plainToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserInviteInfo retrieves invite information from a token (public endpoint).
|
||||
func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
|
||||
if err := types.ValidateInviteToken(token); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
|
||||
}
|
||||
|
||||
hashedToken := types.HashInviteToken(token)
|
||||
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the inviter's name
|
||||
invitedBy := ""
|
||||
if invite.CreatedBy != "" {
|
||||
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy)
|
||||
if err == nil && inviter != nil {
|
||||
invitedBy = inviter.Name
|
||||
}
|
||||
}
|
||||
|
||||
return &types.UserInviteInfo{
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
ExpiresAt: invite.ExpiresAt,
|
||||
Valid: !invite.IsExpired(),
|
||||
InvitedBy: invitedBy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListUserInvites returns all invites for an account.
|
||||
func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
invites := make([]*types.UserInvite, 0, len(records))
|
||||
for _, record := range records {
|
||||
invites = append(invites, &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: record.ID,
|
||||
Email: record.Email,
|
||||
Name: record.Name,
|
||||
Role: record.Role,
|
||||
AutoGroups: record.AutoGroups,
|
||||
},
|
||||
InviteExpiresAt: record.ExpiresAt,
|
||||
InviteCreatedAt: record.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return invites, nil
|
||||
}
|
||||
|
||||
// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB.
|
||||
func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
if password == "" {
|
||||
return status.Errorf(status.InvalidArgument, "password is required")
|
||||
}
|
||||
|
||||
if err := validatePassword(password); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid password: %v", err)
|
||||
}
|
||||
|
||||
if err := types.ValidateInviteToken(token); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
|
||||
}
|
||||
|
||||
hashedToken := types.HashInviteToken(token)
|
||||
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if invite.IsExpired() {
|
||||
return status.Errorf(status.InvalidArgument, "invite has expired")
|
||||
}
|
||||
|
||||
// Create user in Dex with the provided password
|
||||
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return status.Errorf(status.Internal, "failed to get embedded IdP manager")
|
||||
}
|
||||
|
||||
idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user in IdP: %w", err)
|
||||
}
|
||||
|
||||
// Create user in NetBird DB
|
||||
newUser := &types.User{
|
||||
Id: idpUser.ID,
|
||||
AccountID: invite.AccountID,
|
||||
Role: types.StrRoleToUserRole(invite.Role),
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Issued: types.UserIssuedAPI,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.SaveUser(ctx, newUser); err != nil {
|
||||
return fmt.Errorf("failed to save user: %w", err)
|
||||
}
|
||||
if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete invite: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// Best-effort rollback: delete the IdP user to avoid orphaned records
|
||||
if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil {
|
||||
log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one.
|
||||
func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
// Get existing invite
|
||||
existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate expiration time
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = types.DefaultInviteExpirationSeconds
|
||||
}
|
||||
if expiresIn < types.MinInviteExpirationSeconds {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
|
||||
}
|
||||
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Generate new invite token
|
||||
hashedToken, plainToken, err := types.GenerateInviteToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate invite token: %w", err)
|
||||
}
|
||||
|
||||
// Update existing invite with new token and expiration
|
||||
existingInvite.HashedToken = hashedToken
|
||||
existingInvite.ExpiresAt = expiresAt
|
||||
existingInvite.CreatedBy = initiatorUserID
|
||||
|
||||
err = am.Store.SaveUserInvite(ctx, existingInvite)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email})
|
||||
|
||||
return &types.UserInvite{
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: existingInvite.ID,
|
||||
Email: existingInvite.Email,
|
||||
Name: existingInvite.Name,
|
||||
Role: existingInvite.Role,
|
||||
AutoGroups: existingInvite.AutoGroups,
|
||||
Status: string(types.UserStatusInvited),
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
InviteToken: plainToken,
|
||||
InviteExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteUserInvite deletes an existing invite by ID.
|
||||
func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// validatePassword checks password strength requirements:
|
||||
// - Minimum 8 characters
|
||||
// - At least 1 digit
|
||||
// - At least 1 uppercase letter
|
||||
// - At least 1 special character
|
||||
func validatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
var hasDigit, hasUpper, hasSpecial bool
|
||||
for _, c := range password {
|
||||
switch {
|
||||
case unicode.IsDigit(c):
|
||||
hasDigit = true
|
||||
case unicode.IsUpper(c):
|
||||
hasUpper = true
|
||||
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
var missing []string
|
||||
if !hasDigit {
|
||||
missing = append(missing, "one digit")
|
||||
}
|
||||
if !hasUpper {
|
||||
missing = append(missing, "one uppercase letter")
|
||||
}
|
||||
if !hasSpecial {
|
||||
missing = append(missing, "one special character")
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
return errors.New("password must contain at least " + strings.Join(missing, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
1010
management/server/user_invite_test.go
Normal file
1010
management/server/user_invite_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -488,6 +488,171 @@ components:
|
||||
- role
|
||||
- auto_groups
|
||||
- is_service_user
|
||||
UserInviteCreateRequest:
|
||||
type: object
|
||||
description: Request to create a user invite link
|
||||
properties:
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
example: user@example.com
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
example: John Doe
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
example: user
|
||||
auto_groups:
|
||||
description: Group IDs to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
expires_in:
|
||||
description: Invite expiration time in seconds (default 72 hours)
|
||||
type: integer
|
||||
example: 259200
|
||||
required:
|
||||
- email
|
||||
- name
|
||||
- role
|
||||
- auto_groups
|
||||
UserInvite:
|
||||
type: object
|
||||
description: A user invite
|
||||
properties:
|
||||
id:
|
||||
description: Invite ID
|
||||
type: string
|
||||
example: d5p7eedra0h0lt6f59hg
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
example: user@example.com
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
example: John Doe
|
||||
role:
|
||||
description: User's NetBird account role
|
||||
type: string
|
||||
example: user
|
||||
auto_groups:
|
||||
description: Group IDs to auto-assign to peers registered by this user
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
expires_at:
|
||||
description: Invite expiration time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-25T10:00:00Z"
|
||||
created_at:
|
||||
description: Invite creation time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-22T10:00:00Z"
|
||||
expired:
|
||||
description: Whether the invite has expired
|
||||
type: boolean
|
||||
example: false
|
||||
invite_token:
|
||||
description: The invite link to be shared with the user. Only returned when the invite is created or regenerated.
|
||||
type: string
|
||||
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
|
||||
required:
|
||||
- id
|
||||
- email
|
||||
- name
|
||||
- role
|
||||
- auto_groups
|
||||
- expires_at
|
||||
- created_at
|
||||
- expired
|
||||
UserInviteInfo:
|
||||
type: object
|
||||
description: Public information about an invite
|
||||
properties:
|
||||
email:
|
||||
description: User's email address
|
||||
type: string
|
||||
example: user@example.com
|
||||
name:
|
||||
description: User's full name
|
||||
type: string
|
||||
example: John Doe
|
||||
expires_at:
|
||||
description: Invite expiration time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-25T10:00:00Z"
|
||||
valid:
|
||||
description: Whether the invite is still valid (not expired)
|
||||
type: boolean
|
||||
example: true
|
||||
invited_by:
|
||||
description: Name of the user who sent the invite
|
||||
type: string
|
||||
example: Admin User
|
||||
required:
|
||||
- email
|
||||
- name
|
||||
- expires_at
|
||||
- valid
|
||||
- invited_by
|
||||
UserInviteAcceptRequest:
|
||||
type: object
|
||||
description: Request to accept an invite and set password
|
||||
properties:
|
||||
password:
|
||||
description: >-
|
||||
The password the user wants to set. Must be at least 8 characters long
|
||||
and contain at least one uppercase letter, one digit, and one special
|
||||
character (any character that is not a letter or digit, including spaces).
|
||||
type: string
|
||||
format: password
|
||||
minLength: 8
|
||||
pattern: '^(?=.*[0-9])(?=.*[A-Z])(?=.*[^a-zA-Z0-9]).{8,}$'
|
||||
example: SecurePass123!
|
||||
required:
|
||||
- password
|
||||
UserInviteAcceptResponse:
|
||||
type: object
|
||||
description: Response after accepting an invite
|
||||
properties:
|
||||
success:
|
||||
description: Whether the invite was accepted successfully
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- success
|
||||
UserInviteRegenerateRequest:
|
||||
type: object
|
||||
description: Request to regenerate an invite link
|
||||
properties:
|
||||
expires_in:
|
||||
description: Invite expiration time in seconds (default 72 hours)
|
||||
type: integer
|
||||
example: 259200
|
||||
UserInviteRegenerateResponse:
|
||||
type: object
|
||||
description: Response after regenerating an invite
|
||||
properties:
|
||||
invite_token:
|
||||
description: The new invite token
|
||||
type: string
|
||||
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
|
||||
invite_expires_at:
|
||||
description: New invite expiration time
|
||||
type: string
|
||||
format: date-time
|
||||
example: "2024-01-28T10:00:00Z"
|
||||
required:
|
||||
- invite_token
|
||||
- invite_expires_at
|
||||
PeerMinimum:
|
||||
type: object
|
||||
properties:
|
||||
@@ -2071,7 +2236,8 @@ components:
|
||||
"dns.zone.create", "dns.zone.update", "dns.zone.delete",
|
||||
"dns.zone.record.create", "dns.zone.record.update", "dns.zone.record.delete",
|
||||
"peer.job.create",
|
||||
"user.password.change"
|
||||
"user.password.change",
|
||||
"user.invite.link.create", "user.invite.link.accept", "user.invite.link.regenerate", "user.invite.link.delete"
|
||||
]
|
||||
example: route.add
|
||||
initiator_id:
|
||||
@@ -2642,6 +2808,29 @@ components:
|
||||
required:
|
||||
- user_id
|
||||
- email
|
||||
InstanceVersionInfo:
|
||||
type: object
|
||||
description: Version information for NetBird components
|
||||
properties:
|
||||
management_current_version:
|
||||
description: The current running version of the management server
|
||||
type: string
|
||||
example: "0.35.0"
|
||||
dashboard_available_version:
|
||||
description: The latest available version of the dashboard (from GitHub releases)
|
||||
type: string
|
||||
example: "2.10.0"
|
||||
management_available_version:
|
||||
description: The latest available version of the management server (from GitHub releases)
|
||||
type: string
|
||||
example: "0.35.0"
|
||||
management_update_available:
|
||||
description: Indicates if a newer management version is available
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- management_current_version
|
||||
- management_update_available
|
||||
responses:
|
||||
not_found:
|
||||
description: Resource not found
|
||||
@@ -2694,6 +2883,27 @@ paths:
|
||||
$ref: '#/components/schemas/InstanceStatus'
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/instance/version:
|
||||
get:
|
||||
summary: Get Version Info
|
||||
description: Returns version information for NetBird components including the current management server version and latest available versions from GitHub.
|
||||
tags: [ Instance ]
|
||||
security:
|
||||
- BearerAuth: []
|
||||
- TokenAuth: []
|
||||
responses:
|
||||
'200':
|
||||
description: Version information
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/InstanceVersionInfo'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/setup:
|
||||
post:
|
||||
summary: Setup Instance
|
||||
@@ -3312,6 +3522,210 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites:
|
||||
get:
|
||||
summary: List user invites
|
||||
description: Lists all pending invites for the account. Only available when embedded IdP is enabled.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
responses:
|
||||
'200':
|
||||
description: List of invites
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/UserInvite'
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
post:
|
||||
summary: Create a user invite
|
||||
description: Creates an invite link for a new user. Only available when embedded IdP is enabled. The user is not created until they accept the invite.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
requestBody:
|
||||
description: User invite information
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteCreateRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Invite created successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInvite'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'409':
|
||||
description: User or invite already exists
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'422':
|
||||
"$ref": "#/components/responses/validation_failed"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{inviteId}:
|
||||
delete:
|
||||
summary: Delete a user invite
|
||||
description: Deletes a pending invite. Only available when embedded IdP is enabled.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: inviteId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the invite to delete
|
||||
responses:
|
||||
'200':
|
||||
description: Invite deleted successfully
|
||||
content: { }
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
description: Invite not found
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{inviteId}/regenerate:
|
||||
post:
|
||||
summary: Regenerate a user invite
|
||||
description: Regenerates an invite link for an existing invite. Invalidates the previous token and creates a new one.
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: inviteId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the invite to regenerate
|
||||
requestBody:
|
||||
description: Regenerate options
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteRegenerateRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Invite regenerated successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteRegenerateResponse'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'404':
|
||||
description: Invite not found
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled
|
||||
content: { }
|
||||
'422':
|
||||
"$ref": "#/components/responses/validation_failed"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{token}:
|
||||
get:
|
||||
summary: Get invite information
|
||||
description: Retrieves public information about an invite. This endpoint is unauthenticated and protected by the token itself.
|
||||
tags: [ Users ]
|
||||
security: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: token
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The invite token
|
||||
responses:
|
||||
'200':
|
||||
description: Invite information
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteInfo'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'404':
|
||||
description: Invite not found or invalid token
|
||||
content: { }
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/invites/{token}/accept:
|
||||
post:
|
||||
summary: Accept an invite
|
||||
description: Accepts an invite and creates the user with the provided password. This endpoint is unauthenticated and protected by the token itself.
|
||||
tags: [ Users ]
|
||||
security: []
|
||||
parameters:
|
||||
- in: path
|
||||
name: token
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The invite token
|
||||
requestBody:
|
||||
description: Password to set for the new user
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteAcceptRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: Invite accepted successfully
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/UserInviteAcceptResponse'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'404':
|
||||
description: Invite not found or invalid token
|
||||
content: { }
|
||||
'412':
|
||||
description: Precondition failed - embedded IdP is not enabled or invite expired
|
||||
content: { }
|
||||
'422':
|
||||
"$ref": "#/components/responses/validation_failed"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/peers:
|
||||
get:
|
||||
summary: List all Peers
|
||||
|
||||
@@ -123,6 +123,10 @@ const (
|
||||
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
|
||||
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
|
||||
EventActivityCodeUserInvite EventActivityCode = "user.invite"
|
||||
EventActivityCodeUserInviteLinkAccept EventActivityCode = "user.invite.link.accept"
|
||||
EventActivityCodeUserInviteLinkCreate EventActivityCode = "user.invite.link.create"
|
||||
EventActivityCodeUserInviteLinkDelete EventActivityCode = "user.invite.link.delete"
|
||||
EventActivityCodeUserInviteLinkRegenerate EventActivityCode = "user.invite.link.regenerate"
|
||||
EventActivityCodeUserJoin EventActivityCode = "user.join"
|
||||
EventActivityCodeUserPasswordChange EventActivityCode = "user.password.change"
|
||||
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
|
||||
@@ -870,6 +874,21 @@ type InstanceStatus struct {
|
||||
SetupRequired bool `json:"setup_required"`
|
||||
}
|
||||
|
||||
// InstanceVersionInfo Version information for NetBird components
|
||||
type InstanceVersionInfo struct {
|
||||
// DashboardAvailableVersion The latest available version of the dashboard (from GitHub releases)
|
||||
DashboardAvailableVersion *string `json:"dashboard_available_version,omitempty"`
|
||||
|
||||
// ManagementAvailableVersion The latest available version of the management server (from GitHub releases)
|
||||
ManagementAvailableVersion *string `json:"management_available_version,omitempty"`
|
||||
|
||||
// ManagementCurrentVersion The current running version of the management server
|
||||
ManagementCurrentVersion string `json:"management_current_version"`
|
||||
|
||||
// ManagementUpdateAvailable Indicates if a newer management version is available
|
||||
ManagementUpdateAvailable bool `json:"management_update_available"`
|
||||
}
|
||||
|
||||
// JobRequest defines model for JobRequest.
|
||||
type JobRequest struct {
|
||||
Workload WorkloadRequest `json:"workload"`
|
||||
@@ -2166,6 +2185,99 @@ type UserCreateRequest struct {
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserInvite A user invite
|
||||
type UserInvite struct {
|
||||
// AutoGroups Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// CreatedAt Invite creation time
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Email User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
// Expired Whether the invite has expired
|
||||
Expired bool `json:"expired"`
|
||||
|
||||
// ExpiresAt Invite expiration time
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
// Id Invite ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// InviteToken The invite link to be shared with the user. Only returned when the invite is created or regenerated.
|
||||
InviteToken *string `json:"invite_token,omitempty"`
|
||||
|
||||
// Name User's full name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserInviteAcceptRequest Request to accept an invite and set password
|
||||
type UserInviteAcceptRequest struct {
|
||||
// Password The password the user wants to set. Must be at least 8 characters long and contain at least one uppercase letter, one digit, and one special character (any character that is not a letter or digit, including spaces).
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// UserInviteAcceptResponse Response after accepting an invite
|
||||
type UserInviteAcceptResponse struct {
|
||||
// Success Whether the invite was accepted successfully
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// UserInviteCreateRequest Request to create a user invite link
|
||||
type UserInviteCreateRequest struct {
|
||||
// AutoGroups Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Email User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
// ExpiresIn Invite expiration time in seconds (default 72 hours)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
|
||||
// Name User's full name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Role User's NetBird account role
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
// UserInviteInfo Public information about an invite
|
||||
type UserInviteInfo struct {
|
||||
// Email User's email address
|
||||
Email string `json:"email"`
|
||||
|
||||
// ExpiresAt Invite expiration time
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
|
||||
// InvitedBy Name of the user who sent the invite
|
||||
InvitedBy string `json:"invited_by"`
|
||||
|
||||
// Name User's full name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Valid Whether the invite is still valid (not expired)
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// UserInviteRegenerateRequest Request to regenerate an invite link
|
||||
type UserInviteRegenerateRequest struct {
|
||||
// ExpiresIn Invite expiration time in seconds (default 72 hours)
|
||||
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||
}
|
||||
|
||||
// UserInviteRegenerateResponse Response after regenerating an invite
|
||||
type UserInviteRegenerateResponse struct {
|
||||
// InviteExpiresAt New invite expiration time
|
||||
InviteExpiresAt time.Time `json:"invite_expires_at"`
|
||||
|
||||
// InviteToken The new invite token
|
||||
InviteToken string `json:"invite_token"`
|
||||
}
|
||||
|
||||
// UserPermissions defines model for UserPermissions.
|
||||
type UserPermissions struct {
|
||||
// IsRestricted Indicates whether this User's Peers view is restricted
|
||||
@@ -2418,6 +2530,15 @@ type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
|
||||
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
|
||||
type PostApiUsersJSONRequestBody = UserCreateRequest
|
||||
|
||||
// PostApiUsersInvitesJSONRequestBody defines body for PostApiUsersInvites for application/json ContentType.
|
||||
type PostApiUsersInvitesJSONRequestBody = UserInviteCreateRequest
|
||||
|
||||
// PostApiUsersInvitesInviteIdRegenerateJSONRequestBody defines body for PostApiUsersInvitesInviteIdRegenerate for application/json ContentType.
|
||||
type PostApiUsersInvitesInviteIdRegenerateJSONRequestBody = UserInviteRegenerateRequest
|
||||
|
||||
// PostApiUsersInvitesTokenAcceptJSONRequestBody defines body for PostApiUsersInvitesTokenAccept for application/json ContentType.
|
||||
type PostApiUsersInvitesTokenAcceptJSONRequestBody = UserInviteAcceptRequest
|
||||
|
||||
// PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType.
|
||||
type PutApiUsersUserIdJSONRequestBody = UserRequest
|
||||
|
||||
|
||||
Reference in New Issue
Block a user