Add user invite link feature for embedded IdP (#5157)

This commit is contained in:
Misha Bragin
2026-01-27 09:42:20 +01:00
committed by GitHub
parent 44ab454a13
commit 7d791620a6
21 changed files with 4832 additions and 2 deletions

View File

@@ -30,6 +30,12 @@ type Manager interface {
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
AcceptUserInvite(ctx context.Context, token, password string) error
RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error)
ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
UpdateUserPassword(ctx context.Context, accountID, currentUserID, targetUserID string, oldPassword, newPassword string) error

View File

@@ -199,6 +199,11 @@ const (
UserPasswordChanged Activity = 103
UserInviteLinkCreated Activity = 104
UserInviteLinkAccepted Activity = 105
UserInviteLinkRegenerated Activity = 106
UserInviteLinkDeleted Activity = 107
AccountDeleted Activity = 99999
)
@@ -327,6 +332,11 @@ var activityMap = map[Activity]Code{
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
UserPasswordChanged: {"User password changed", "user.password.change"},
UserInviteLinkCreated: {"User invite link created", "user.invite.link.create"},
UserInviteLinkAccepted: {"User invite link accepted", "user.invite.link.accept"},
UserInviteLinkRegenerated: {"User invite link regenerated", "user.invite.link.regenerate"},
UserInviteLinkDeleted: {"User invite link deleted", "user.invite.link.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -68,6 +68,13 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if err := bypass.AddBypassPath("/api/setup"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
// Public invite endpoints (tokens start with nbi_)
if err := bypass.AddBypassPath("/api/users/invites/nbi_*"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if err := bypass.AddBypassPath("/api/users/invites/nbi_*/accept"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -132,6 +139,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
@@ -145,6 +154,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {

View File

@@ -28,6 +28,15 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
// This endpoint is unauthenticated.
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
@@ -65,3 +74,29 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
Email: userData.Email,
})
}
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
util.WriteErrorResponse("failed to get version info", http.StatusInternalServerError, w)
return
}
resp := api.InstanceVersionInfo{
ManagementCurrentVersion: versionInfo.CurrentVersion,
ManagementUpdateAvailable: versionInfo.ManagementUpdateAvailable,
}
if versionInfo.DashboardVersion != "" {
resp.DashboardAvailableVersion = &versionInfo.DashboardVersion
}
if versionInfo.ManagementVersion != "" {
resp.ManagementAvailableVersion = &versionInfo.ManagementVersion
}
util.WriteJSONObject(r.Context(), w, resp)
}

View File

@@ -25,6 +25,7 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
@@ -66,6 +67,18 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
}
return &nbinstance.VersionInfo{
CurrentVersion: "0.34.0",
DashboardVersion: "2.0.0",
ManagementVersion: "0.35.0",
ManagementUpdateAvailable: true,
}, nil
}
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
@@ -279,3 +292,44 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceVersionInfo
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, "0.34.0", response.ManagementCurrentVersion)
assert.NotNil(t, response.DashboardAvailableVersion)
assert.Equal(t, "2.0.0", *response.DashboardAvailableVersion)
assert.NotNil(t, response.ManagementAvailableVersion)
assert.Equal(t, "0.35.0", *response.ManagementAvailableVersion)
assert.True(t, response.ManagementUpdateAvailable)
}
func TestGetVersionInfo_Error(t *testing.T) {
manager := &mockInstanceManager{
getVersionInfoFn: func(ctx context.Context) (*nbinstance.VersionInfo, error) {
return nil, errors.New("failed to fetch versions")
},
}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}

View File

@@ -0,0 +1,263 @@
package users
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// publicInviteRateLimiter limits public invite requests by IP address to prevent brute-force attacks
var publicInviteRateLimiter = middleware.NewAPIRateLimiter(&middleware.RateLimiterConfig{
RequestsPerMinute: 10, // 10 attempts per minute per IP
Burst: 5, // Allow burst of 5 requests
CleanupInterval: 10 * time.Minute,
LimiterTTL: 30 * time.Minute,
})
// toUserInviteResponse converts a UserInvite to an API response.
func toUserInviteResponse(invite *types.UserInvite) api.UserInvite {
autoGroups := invite.UserInfo.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
var inviteLink *string
if invite.InviteToken != "" {
inviteLink = &invite.InviteToken
}
return api.UserInvite{
Id: invite.UserInfo.ID,
Email: invite.UserInfo.Email,
Name: invite.UserInfo.Name,
Role: invite.UserInfo.Role,
AutoGroups: autoGroups,
ExpiresAt: invite.InviteExpiresAt.UTC(),
CreatedAt: invite.InviteCreatedAt.UTC(),
Expired: time.Now().After(invite.InviteExpiresAt),
InviteToken: inviteLink,
}
}
// invitesHandler handles user invite operations
type invitesHandler struct {
accountManager account.Manager
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
h := &invitesHandler{accountManager: accountManager}
// Create a subrouter for public invite endpoints with rate limiting middleware
publicRouter := router.PathPrefix("/users/invites").Subrouter()
publicRouter.Use(publicInviteRateLimiter.Middleware)
// Public endpoints (no auth required, protected by token and rate limited)
publicRouter.HandleFunc("/{token}", h.getInviteInfo).Methods("GET", "OPTIONS")
publicRouter.HandleFunc("/{token}/accept", h.acceptInvite).Methods("POST", "OPTIONS")
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := make([]api.UserInvite, 0, len(invites))
for _, invite := range invites {
resp = append(resp, toUserInviteResponse(invite))
}
util.WriteJSONObject(r.Context(), w, resp)
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
invite := &types.UserInfo{
Email: req.Email,
Name: req.Name,
Role: req.Role,
AutoGroups: req.AutoGroups,
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.CreateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, invite, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
result.InviteCreatedAt = time.Now().UTC()
resp := toUserInviteResponse(result)
util.WriteJSONObject(r.Context(), w, &resp)
}
// getInviteInfo handles GET /api/users/invites/{token}
func (h *invitesHandler) getInviteInfo(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
info, err := h.accountManager.GetUserInviteInfo(r.Context(), token)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := info.ExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteInfo{
Email: info.Email,
Name: info.Name,
ExpiresAt: expiresAt,
Valid: info.Valid,
InvitedBy: info.InvitedBy,
})
}
// acceptInvite handles POST /api/users/invites/{token}/accept
func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
token := vars["token"]
if token == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "token is required"), w)
return
}
var req api.UserInviteAcceptRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := h.accountManager.AcceptUserInvite(r.Context(), token, req.Password)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, &api.UserInviteAcceptResponse{Success: true})
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
var req api.UserInviteRegenerateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
// Allow empty body (io.EOF) - expiresIn is optional
if !errors.Is(err, io.EOF) {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
}
expiresIn := 0
if req.ExpiresIn != nil {
expiresIn = *req.ExpiresIn
}
result, err := h.accountManager.RegenerateUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID, expiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
expiresAt := result.InviteExpiresAt.UTC()
util.WriteJSONObject(r.Context(), w, &api.UserInviteRegenerateResponse{
InviteToken: result.InviteToken,
InviteExpiresAt: expiresAt,
})
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invite ID is required"), w)
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,642 @@
package users
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testInviteID = "test-invite-id"
testInviteToken = "nbi_testtoken123456789012345678"
testEmail = "invite@example.com"
testName = "Test User"
)
func setupInvitesTestHandler(am *mock_server.MockAccountManager) *invitesHandler {
return &invitesHandler{
accountManager: am,
}
}
func TestListInvites(t *testing.T) {
now := time.Now().UTC()
testInvites := []*types.UserInvite{
{
UserInfo: &types.UserInfo{
ID: "invite-1",
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
},
InviteExpiresAt: now.Add(24 * time.Hour),
InviteCreatedAt: now,
},
{
UserInfo: &types.UserInfo{
ID: "invite-2",
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: nil,
},
InviteExpiresAt: now.Add(-1 * time.Hour), // Expired
InviteCreatedAt: now.Add(-48 * time.Hour),
},
}
tt := []struct {
name string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
expectedCount int
}{
{
name: "successful list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return testInvites, nil
},
expectedCount: 2,
},
{
name: "empty list",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return []*types.UserInvite{}, nil
},
expectedCount: 0,
},
{
name: "permission denied",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
expectedCount: 0,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
ListUserInvitesFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites", nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp []api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Len(t, resp, tc.expectedCount)
}
})
}
}
func TestCreateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful create",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":["group-1"]}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful create with custom expiration",
requestBody: `{"email":"test@example.com","name":"Test User","role":"admin","auto_groups":[],"expires_in":3600}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 3600, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: testInviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: []string{},
Status: string(types.UserStatusInvited),
},
InviteToken: testInviteToken,
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "user already exists",
requestBody: `{"email":"existing@example.com","name":"Existing User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
},
},
{
name: "invite already exists",
requestBody: `{"email":"invited@example.com","name":"Invited User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusConflict,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
},
},
{
name: "permission denied",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "invalid JSON",
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
CreateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites", bytes.NewBufferString(tc.requestBody))
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInvite
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testInviteID, resp.Id)
assert.NotNil(t, resp.InviteToken)
assert.NotEmpty(t, *resp.InviteToken)
}
})
}
}
func TestGetInviteInfo(t *testing.T) {
now := time.Now().UTC()
tt := []struct {
name string
token string
expectedStatus int
mockFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
}{
{
name: "successful get valid invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(24 * time.Hour),
Valid: true,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "successful get expired invite",
token: testInviteToken,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return &types.UserInviteInfo{
Email: testEmail,
Name: testName,
ExpiresAt: now.Add(-24 * time.Hour),
Valid: false,
InvitedBy: "Admin User",
}, nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invalid token format",
token: "invalid",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token string) (*types.UserInviteInfo, error) {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token")
},
},
{
name: "missing token",
token: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
GetUserInviteInfoFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodGet, "/api/users/invites/"+tc.token, nil)
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.getInviteInfo(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteInfo
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, testEmail, resp.Email)
assert.Equal(t, testName, resp.Name)
}
})
}
}
func TestAcceptInvite(t *testing.T) {
tt := []struct {
name string
token string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, token, password string) error
}{
{
name: "successful accept",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, token, password string) error {
return nil
},
},
{
name: "invite not found",
token: "nbi_invalidtoken1234567890123456",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "invite expired",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "invite has expired")
},
},
{
name: "embedded IDP not enabled",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing token",
token: "",
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON",
token: testInviteToken,
requestBody: `{invalid}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
{
name: "password too short",
token: testInviteToken,
requestBody: `{"password":"Short1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters long")
},
},
{
name: "password missing digit",
token: testInviteToken,
requestBody: `{"password":"NoDigitPass!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one digit")
},
},
{
name: "password missing uppercase",
token: testInviteToken,
requestBody: `{"password":"nouppercase1!"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one uppercase letter")
},
},
{
name: "password missing special character",
token: testInviteToken,
requestBody: `{"password":"NoSpecial123"}`,
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.InvalidArgument, "password must contain at least one special character")
},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
AcceptUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.token+"/accept", bytes.NewBufferString(tc.requestBody))
if tc.token != "" {
req = mux.SetURLVars(req, map[string]string{"token": tc.token})
}
rr := httptest.NewRecorder()
handler.acceptInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteAcceptResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.True(t, resp.Success)
}
})
}
}
func TestRegenerateInvite(t *testing.T) {
now := time.Now().UTC()
expiresAt := now.Add(72 * time.Hour)
tt := []struct {
name string
inviteID string
requestBody string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
}{
{
name: "successful regenerate with empty body",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 0, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "successful regenerate with custom expiration",
inviteID: testInviteID,
requestBody: `{"expires_in":7200}`,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
assert.Equal(t, 7200, expiresIn)
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: testEmail,
},
InviteToken: "nbi_newtoken12345678901234567890",
InviteExpiresAt: expiresAt,
}, nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
requestBody: "",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
requestBody: "",
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
return nil, status.NewPermissionDeniedError()
},
},
{
name: "missing invite ID",
inviteID: "",
requestBody: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
{
name: "invalid JSON should return error",
inviteID: testInviteID,
requestBody: `{invalid json}`,
expectedStatus: http.StatusBadRequest,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
RegenerateUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
var body io.Reader
if tc.requestBody != "" {
body = bytes.NewBufferString(tc.requestBody)
}
req := httptest.NewRequest(http.MethodPost, "/api/users/invites/"+tc.inviteID+"/regenerate", body)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedStatus == http.StatusOK {
var resp api.UserInviteRegenerateResponse
err := json.NewDecoder(rr.Body).Decode(&resp)
require.NoError(t, err)
assert.NotEmpty(t, resp.InviteToken)
}
})
}
}
func TestDeleteInvite(t *testing.T) {
tt := []struct {
name string
inviteID string
expectedStatus int
mockFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}{
{
name: "successful delete",
inviteID: testInviteID,
expectedStatus: http.StatusOK,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return nil
},
},
{
name: "invite not found",
inviteID: "non-existent-invite",
expectedStatus: http.StatusNotFound,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.NotFound, "invite not found")
},
},
{
name: "permission denied",
inviteID: testInviteID,
expectedStatus: http.StatusForbidden,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.NewPermissionDeniedError()
},
},
{
name: "embedded IDP not enabled",
inviteID: testInviteID,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "missing invite ID",
inviteID: "",
expectedStatus: http.StatusUnprocessableEntity,
mockFunc: nil,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{
DeleteUserInviteFunc: tc.mockFunc,
}
handler := setupInvitesTestHandler(am)
req := httptest.NewRequest(http.MethodDelete, "/api/users/invites/"+tc.inviteID, nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
if tc.inviteID != "" {
req = mux.SetURLVars(req, map[string]string{"inviteId": tc.inviteID})
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})
}
}

View File

@@ -2,10 +2,14 @@ package middleware
import (
"context"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// RateLimiterConfig holds configuration for the API rate limiter
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
defer rl.mu.Unlock()
delete(rl.limiters, key)
}
// Middleware returns an HTTP middleware that rate limits requests by client IP.
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
return
}
next.ServeHTTP(w, r)
})
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

View File

@@ -0,0 +1,158 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAPIRateLimiter_Allow(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// First two requests should be allowed (burst)
assert.True(t, rl.Allow("test-key"))
assert.True(t, rl.Allow("test-key"))
// Third request should be denied (exceeded burst)
assert.False(t, rl.Allow("test-key"))
// Different key should be allowed
assert.True(t, rl.Allow("different-key"))
}
func TestAPIRateLimiter_Middleware(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Wrap with rate limiter middleware
handler := rl.Middleware(nextHandler)
// First two requests should pass (burst)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
}
// Third request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
}
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := rl.Middleware(nextHandler)
// Request from first IP
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
assert.Equal(t, http.StatusOK, rr1.Code)
// Second request from first IP should be rate limited
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
// Request from different IP should be allowed
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "192.168.1.2:12345"
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
assert.Equal(t, http.StatusOK, rr3.Code)
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{
name: "remote addr with port",
remoteAddr: "192.168.1.1:12345",
expected: "192.168.1.1",
},
{
name: "remote addr without port",
remoteAddr: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
remoteAddr: "[::1]:12345",
expected: "::1",
},
{
name: "IPv6 without port",
remoteAddr: "::1",
expected: "::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tc.remoteAddr
assert.Equal(t, tc.expected, getClientIP(req))
})
}
}
func TestAPIRateLimiter_Reset(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Use up the burst
assert.True(t, rl.Allow("test-key"))
assert.False(t, rl.Allow("test-key"))
// Reset the limiter
rl.Reset("test-key")
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}

View File

@@ -2,18 +2,54 @@ package instance
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/mail"
"strings"
"sync"
"time"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/version"
)
const (
// Version endpoints
managementVersionURL = "https://pkgs.netbird.io/releases/latest/version"
dashboardReleasesURL = "https://api.github.com/repos/netbirdio/dashboard/releases/latest"
// Cache TTL for version information
versionCacheTTL = 60 * time.Minute
// HTTP client timeout
httpTimeout = 5 * time.Second
)
// VersionInfo contains version information for NetBird components
type VersionInfo struct {
// CurrentVersion is the running management server version
CurrentVersion string
// DashboardVersion is the latest available dashboard version from GitHub
DashboardVersion string
// ManagementVersion is the latest available management version from GitHub
ManagementVersion string
// ManagementUpdateAvailable indicates if a newer management version is available
ManagementUpdateAvailable bool
}
// githubRelease represents a GitHub release response
type githubRelease struct {
TagName string `json:"tag_name"`
}
// Manager handles instance-level operations like initial setup.
type Manager interface {
// IsSetupRequired checks if instance setup is required.
@@ -23,6 +59,9 @@ type Manager interface {
// CreateOwnerUser creates the initial owner user in the embedded IDP.
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
// GetVersionInfo returns version information for NetBird components.
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
// DefaultManager is the default implementation of Manager.
@@ -32,6 +71,12 @@ type DefaultManager struct {
setupRequired bool
setupMu sync.RWMutex
// Version caching
httpClient *http.Client
versionMu sync.RWMutex
cachedVersions *VersionInfo
lastVersionFetch time.Time
}
// NewManager creates a new instance manager.
@@ -43,6 +88,9 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
httpClient: &http.Client{
Timeout: httpTimeout,
},
}
if embeddedIdp != nil {
@@ -134,3 +182,130 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
}
return nil
}
// GetVersionInfo returns version information for NetBird components.
func (m *DefaultManager) GetVersionInfo(ctx context.Context) (*VersionInfo, error) {
m.versionMu.RLock()
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
cached := *m.cachedVersions
m.versionMu.RUnlock()
return &cached, nil
}
m.versionMu.RUnlock()
return m.fetchVersionInfo(ctx)
}
func (m *DefaultManager) fetchVersionInfo(ctx context.Context) (*VersionInfo, error) {
m.versionMu.Lock()
// Double-check after acquiring write lock
if m.cachedVersions != nil && time.Since(m.lastVersionFetch) < versionCacheTTL {
cached := *m.cachedVersions
m.versionMu.Unlock()
return &cached, nil
}
m.versionMu.Unlock()
info := &VersionInfo{
CurrentVersion: version.NetbirdVersion(),
}
// Fetch management version from pkgs.netbird.io (plain text)
mgmtVersion, err := m.fetchPlainTextVersion(ctx, managementVersionURL)
if err != nil {
log.WithContext(ctx).Warnf("failed to fetch management version: %v", err)
} else {
info.ManagementVersion = mgmtVersion
info.ManagementUpdateAvailable = isNewerVersion(info.CurrentVersion, mgmtVersion)
}
// Fetch dashboard version from GitHub
dashVersion, err := m.fetchGitHubRelease(ctx, dashboardReleasesURL)
if err != nil {
log.WithContext(ctx).Warnf("failed to fetch dashboard version from GitHub: %v", err)
} else {
info.DashboardVersion = dashVersion
}
// Update cache
m.versionMu.Lock()
m.cachedVersions = info
m.lastVersionFetch = time.Now()
m.versionMu.Unlock()
return info, nil
}
// isNewerVersion returns true if latestVersion is greater than currentVersion
func isNewerVersion(currentVersion, latestVersion string) bool {
current, err := goversion.NewVersion(currentVersion)
if err != nil {
return false
}
latest, err := goversion.NewVersion(latestVersion)
if err != nil {
return false
}
return latest.GreaterThan(current)
}
func (m *DefaultManager) fetchPlainTextVersion(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
resp, err := m.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 100))
if err != nil {
return "", fmt.Errorf("read response: %w", err)
}
return strings.TrimSpace(string(body)), nil
}
func (m *DefaultManager) fetchGitHubRelease(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "NetBird-Management/"+version.NetbirdVersion())
resp, err := m.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("execute request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var release githubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return "", fmt.Errorf("decode response: %w", err)
}
// Remove 'v' prefix if present
tag := release.TagName
if len(tag) > 0 && tag[0] == 'v' {
tag = tag[1:]
}
return tag, nil
}

View File

@@ -0,0 +1,285 @@
package instance
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockRoundTripper implements http.RoundTripper for testing
type mockRoundTripper struct {
callCount atomic.Int32
managementVersion string
dashboardVersion string
}
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
m.callCount.Add(1)
var body string
if strings.Contains(req.URL.String(), "pkgs.netbird.io") {
// Plain text response for management version
body = m.managementVersion
} else if strings.Contains(req.URL.String(), "github.com") {
// JSON response for dashboard version
jsonResp, _ := json.Marshal(githubRelease{TagName: "v" + m.dashboardVersion})
body = string(jsonResp)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(body)),
Header: make(http.Header),
}, nil
}
func TestDefaultManager_GetVersionInfo_ReturnsCurrentVersion(t *testing.T) {
mockTransport := &mockRoundTripper{
managementVersion: "0.65.0",
dashboardVersion: "2.10.0",
}
m := &DefaultManager{
httpClient: &http.Client{Transport: mockTransport},
}
ctx := context.Background()
info, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
// CurrentVersion should always be set
assert.NotEmpty(t, info.CurrentVersion)
assert.Equal(t, "0.65.0", info.ManagementVersion)
assert.Equal(t, "2.10.0", info.DashboardVersion)
assert.Equal(t, int32(2), mockTransport.callCount.Load()) // 2 calls: management + dashboard
}
func TestDefaultManager_GetVersionInfo_CachesResults(t *testing.T) {
mockTransport := &mockRoundTripper{
managementVersion: "0.65.0",
dashboardVersion: "2.10.0",
}
m := &DefaultManager{
httpClient: &http.Client{Transport: mockTransport},
}
ctx := context.Background()
// First call
info1, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
assert.NotEmpty(t, info1.CurrentVersion)
assert.Equal(t, "0.65.0", info1.ManagementVersion)
initialCallCount := mockTransport.callCount.Load()
// Second call should use cache (no additional HTTP calls)
info2, err := m.GetVersionInfo(ctx)
require.NoError(t, err)
assert.Equal(t, info1.CurrentVersion, info2.CurrentVersion)
assert.Equal(t, info1.ManagementVersion, info2.ManagementVersion)
assert.Equal(t, info1.DashboardVersion, info2.DashboardVersion)
// Verify no additional HTTP calls were made (cache was used)
assert.Equal(t, initialCallCount, mockTransport.callCount.Load())
}
func TestDefaultManager_FetchGitHubRelease_ParsesTagName(t *testing.T) {
tests := []struct {
name string
tagName string
expected string
shouldError bool
}{
{
name: "tag with v prefix",
tagName: "v1.2.3",
expected: "1.2.3",
},
{
name: "tag without v prefix",
tagName: "1.2.3",
expected: "1.2.3",
},
{
name: "tag with prerelease",
tagName: "v2.0.0-beta.1",
expected: "2.0.0-beta.1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(githubRelease{TagName: tc.tagName})
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
version, err := m.fetchGitHubRelease(context.Background(), server.URL)
if tc.shouldError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.expected, version)
}
})
}
}
func TestDefaultManager_FetchGitHubRelease_HandlesErrors(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
}{
{
name: "not found",
statusCode: http.StatusNotFound,
body: `{"message": "Not Found"}`,
},
{
name: "rate limited",
statusCode: http.StatusForbidden,
body: `{"message": "API rate limit exceeded"}`,
},
{
name: "server error",
statusCode: http.StatusInternalServerError,
body: `{"message": "Internal Server Error"}`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.statusCode)
_, _ = w.Write([]byte(tc.body))
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
assert.Error(t, err)
})
}
}
func TestDefaultManager_FetchGitHubRelease_InvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{invalid json}`))
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
_, err := m.fetchGitHubRelease(context.Background(), server.URL)
assert.Error(t, err)
}
func TestDefaultManager_FetchGitHubRelease_ContextCancellation(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(githubRelease{TagName: "v1.0.0"})
}))
defer server.Close()
m := &DefaultManager{
httpClient: &http.Client{Timeout: 5 * time.Second},
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := m.fetchGitHubRelease(ctx, server.URL)
assert.Error(t, err)
}
func TestIsNewerVersion(t *testing.T) {
tests := []struct {
name string
currentVersion string
latestVersion string
expected bool
}{
{
name: "latest is newer - minor version",
currentVersion: "0.64.1",
latestVersion: "0.65.0",
expected: true,
},
{
name: "latest is newer - patch version",
currentVersion: "0.64.1",
latestVersion: "0.64.2",
expected: true,
},
{
name: "latest is newer - major version",
currentVersion: "0.64.1",
latestVersion: "1.0.0",
expected: true,
},
{
name: "versions are equal",
currentVersion: "0.64.1",
latestVersion: "0.64.1",
expected: false,
},
{
name: "current is newer - minor version",
currentVersion: "0.65.0",
latestVersion: "0.64.1",
expected: false,
},
{
name: "current is newer - patch version",
currentVersion: "0.64.2",
latestVersion: "0.64.1",
expected: false,
},
{
name: "development version",
currentVersion: "development",
latestVersion: "0.65.0",
expected: false,
},
{
name: "invalid latest version",
currentVersion: "0.64.1",
latestVersion: "invalid",
expected: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := isNewerVersion(tc.currentVersion, tc.latestVersion)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -139,6 +139,12 @@ type MockAccountManager struct {
CreatePeerJobFunc func(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobsFunc func(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByIDFunc func(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
CreateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error)
AcceptUserInviteFunc func(ctx context.Context, token, password string) error
RegenerateUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error)
GetUserInviteInfoFunc func(ctx context.Context, token string) (*types.UserInviteInfo, error)
ListUserInvitesFunc func(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error)
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error {
@@ -713,6 +719,48 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
func (am *MockAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
if am.CreateUserInviteFunc != nil {
return am.CreateUserInviteFunc(ctx, accountID, initiatorUserID, invite, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateUserInvite is not implemented")
}
func (am *MockAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
if am.AcceptUserInviteFunc != nil {
return am.AcceptUserInviteFunc(ctx, token, password)
}
return status.Errorf(codes.Unimplemented, "method AcceptUserInvite is not implemented")
}
func (am *MockAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
if am.RegenerateUserInviteFunc != nil {
return am.RegenerateUserInviteFunc(ctx, accountID, initiatorUserID, inviteID, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method RegenerateUserInvite is not implemented")
}
func (am *MockAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
if am.GetUserInviteInfoFunc != nil {
return am.GetUserInviteInfoFunc(ctx, token)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserInviteInfo is not implemented")
}
func (am *MockAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
if am.ListUserInvitesFunc != nil {
return am.ListUserInvitesFunc(ctx, accountID, initiatorUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListUserInvites is not implemented")
}
func (am *MockAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
if am.DeleteUserInviteFunc != nil {
return am.DeleteUserInviteFunc(ctx, accountID, initiatorUserID, inviteID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUserInvite is not implemented")
}
func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) {
if am.GetAccountIDFromUserAuthFunc != nil {
return am.GetAccountIDFromUserAuthFunc(ctx, userAuth)

View File

@@ -126,7 +126,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -815,6 +815,130 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return &user, nil
}
// SaveUserInvite saves a user invite to the database
func (s *SqlStore) SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error {
inviteCopy := invite.Copy()
if err := inviteCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt invite: %w", err)
}
result := s.db.Save(inviteCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save user invite to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user invite to store")
}
return nil
}
// GetUserInviteByID retrieves a user invite by its ID and account ID
func (s *SqlStore) GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invite types.UserInviteRecord
result := tx.Where("account_id = ?", accountID).Take(&invite, idQueryCondition, inviteID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user invite not found")
}
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
}
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
return &invite, nil
}
// GetUserInviteByHashedToken retrieves a user invite by its hashed token
func (s *SqlStore) GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invite types.UserInviteRecord
result := tx.Take(&invite, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "user invite not found")
}
log.WithContext(ctx).Errorf("failed to get user invite from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invite from store")
}
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
return &invite, nil
}
// GetUserInviteByEmail retrieves a user invite by account ID and email.
// Since email is encrypted with random IVs, we fetch all invites for the account
// and compare emails in memory after decryption.
func (s *SqlStore) GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invites []*types.UserInviteRecord
result := tx.Find(&invites, "account_id = ?", accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
}
for _, invite := range invites {
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
if strings.EqualFold(invite.Email, email) {
return invite, nil
}
}
return nil, status.Errorf(status.NotFound, "user invite not found for email")
}
// GetAccountUserInvites retrieves all user invites for an account
func (s *SqlStore) GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var invites []*types.UserInviteRecord
result := tx.Find(&invites, "account_id = ?", accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get user invites from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get user invites from store")
}
for _, invite := range invites {
if err := invite.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt invite: %w", err)
}
}
return invites, nil
}
// DeleteUserInvite deletes a user invite by its ID
func (s *SqlStore) DeleteUserInvite(ctx context.Context, inviteID string) error {
result := s.db.Delete(&types.UserInviteRecord{}, idQueryCondition, inviteID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete user invite from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete user invite from store")
}
return nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -0,0 +1,520 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/types"
)
func TestSqlStore_SaveUserInvite(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-1",
AccountID: "account-1",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group-1", "group-2"},
HashedToken: "hashed-token-123",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify the invite was saved
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
assert.Equal(t, invite.Name, retrieved.Name)
assert.Equal(t, invite.Role, retrieved.Role)
assert.Equal(t, invite.AutoGroups, retrieved.AutoGroups)
assert.Equal(t, invite.CreatedBy, retrieved.CreatedBy)
})
}
func TestSqlStore_SaveUserInvite_Update(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-update",
AccountID: "account-1",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-123",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Update the invite with a new token
invite.HashedToken = "new-hashed-token"
invite.ExpiresAt = time.Now().Add(24 * time.Hour)
err = store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify the update
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, "new-hashed-token", retrieved.HashedToken)
})
}
func TestSqlStore_GetUserInviteByID(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-id",
AccountID: "account-1",
Email: "getbyid@example.com",
Name: "Get By ID User",
Role: "admin",
AutoGroups: []string{},
HashedToken: "hashed-token-get",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by ID - success
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
// Get by ID - wrong account
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, "wrong-account", invite.ID)
assert.Error(t, err)
// Get by ID - not found
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, "non-existent")
assert.Error(t, err)
})
}
func TestSqlStore_GetUserInviteByHashedToken(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-token",
AccountID: "account-1",
Email: "getbytoken@example.com",
Name: "Get By Token User",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "unique-hashed-token-456",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by hashed token - success
retrieved, err := store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, invite.HashedToken)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
assert.Equal(t, invite.Email, retrieved.Email)
// Get by hashed token - not found
_, err = store.GetUserInviteByHashedToken(ctx, LockingStrengthNone, "non-existent-token")
assert.Error(t, err)
})
}
func TestSqlStore_GetUserInviteByEmail(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-get-by-email",
AccountID: "account-email-test",
Email: "unique-email@example.com",
Name: "Get By Email User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-email",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Get by email - success
retrieved, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, invite.Email)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
// Get by email - case insensitive
retrieved, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "UNIQUE-EMAIL@EXAMPLE.COM")
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
// Get by email - wrong account
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, "wrong-account", invite.Email)
assert.Error(t, err)
// Get by email - not found
_, err = store.GetUserInviteByEmail(ctx, LockingStrengthNone, invite.AccountID, "nonexistent@example.com")
assert.Error(t, err)
})
}
func TestSqlStore_GetAccountUserInvites(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
accountID := "account-list-invites"
invites := []*types.UserInviteRecord{
{
ID: "invite-list-1",
AccountID: accountID,
Email: "user1@example.com",
Name: "User One",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-list-1",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
{
ID: "invite-list-2",
AccountID: accountID,
Email: "user2@example.com",
Name: "User Two",
Role: "admin",
AutoGroups: []string{"group-2"},
HashedToken: "hashed-token-list-2",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
{
ID: "invite-list-3",
AccountID: "different-account",
Email: "user3@example.com",
Name: "User Three",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-list-3",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
},
}
for _, invite := range invites {
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
}
// Get all invites for the account
retrieved, err := store.GetAccountUserInvites(ctx, LockingStrengthNone, accountID)
require.NoError(t, err)
assert.Len(t, retrieved, 2)
// Verify the invites belong to the correct account
for _, invite := range retrieved {
assert.Equal(t, accountID, invite.AccountID)
}
// Get invites for account with no invites
retrieved, err = store.GetAccountUserInvites(ctx, LockingStrengthNone, "empty-account")
require.NoError(t, err)
assert.Len(t, retrieved, 0)
})
}
func TestSqlStore_DeleteUserInvite(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-delete",
AccountID: "account-delete-test",
Email: "delete@example.com",
Name: "Delete User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-delete",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Verify invite exists
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Delete the invite
err = store.DeleteUserInvite(ctx, invite.ID)
require.NoError(t, err)
// Verify invite is deleted
_, err = store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
assert.Error(t, err)
})
}
func TestSqlStore_UserInvite_EncryptedFields(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-encrypted",
AccountID: "account-encrypted",
Email: "sensitive-email@example.com",
Name: "Sensitive Name",
Role: "user",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-encrypted",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Retrieve and verify decryption works
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, "sensitive-email@example.com", retrieved.Email)
assert.Equal(t, "Sensitive Name", retrieved.Name)
})
}
func TestSqlStore_DeleteUserInvite_NonExistent(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
// Deleting a non-existent invite should not return an error
err := store.DeleteUserInvite(ctx, "non-existent-invite-id")
require.NoError(t, err)
})
}
func TestSqlStore_UserInvite_SameEmailDifferentAccounts(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
email := "shared-email@example.com"
// Create invite in first account
invite1 := &types.UserInviteRecord{
ID: "invite-account1",
AccountID: "account-1",
Email: email,
Name: "User Account 1",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-account1",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-1",
}
// Create invite in second account with same email
invite2 := &types.UserInviteRecord{
ID: "invite-account2",
AccountID: "account-2",
Email: email,
Name: "User Account 2",
Role: "admin",
AutoGroups: []string{"group-1"},
HashedToken: "hashed-token-account2",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-2",
}
err := store.SaveUserInvite(ctx, invite1)
require.NoError(t, err)
err = store.SaveUserInvite(ctx, invite2)
require.NoError(t, err)
// Verify each account gets the correct invite by email
retrieved1, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-1", email)
require.NoError(t, err)
assert.Equal(t, "invite-account1", retrieved1.ID)
assert.Equal(t, "User Account 1", retrieved1.Name)
retrieved2, err := store.GetUserInviteByEmail(ctx, LockingStrengthNone, "account-2", email)
require.NoError(t, err)
assert.Equal(t, "invite-account2", retrieved2.ID)
assert.Equal(t, "User Account 2", retrieved2.Name)
})
}
func TestSqlStore_UserInvite_LockingStrength(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
invite := &types.UserInviteRecord{
ID: "invite-locking",
AccountID: "account-locking",
Email: "locking@example.com",
Name: "Locking Test User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-locking",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
// Test with different locking strengths
lockStrengths := []LockingStrength{LockingStrengthNone, LockingStrengthShare, LockingStrengthUpdate}
for _, strength := range lockStrengths {
retrieved, err := store.GetUserInviteByID(ctx, strength, invite.AccountID, invite.ID)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
retrieved, err = store.GetUserInviteByHashedToken(ctx, strength, invite.HashedToken)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
retrieved, err = store.GetUserInviteByEmail(ctx, strength, invite.AccountID, invite.Email)
require.NoError(t, err)
assert.Equal(t, invite.ID, retrieved.ID)
invites, err := store.GetAccountUserInvites(ctx, strength, invite.AccountID)
require.NoError(t, err)
assert.Len(t, invites, 1)
}
})
}
func TestSqlStore_UserInvite_EmptyAutoGroups(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
// Test with nil AutoGroups
invite := &types.UserInviteRecord{
ID: "invite-nil-autogroups",
AccountID: "account-autogroups",
Email: "nilgroups@example.com",
Name: "Nil Groups User",
Role: "user",
AutoGroups: nil,
HashedToken: "hashed-token-nil",
ExpiresAt: time.Now().Add(72 * time.Hour),
CreatedAt: time.Now(),
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Should return empty slice or nil, both are acceptable
assert.Empty(t, retrieved.AutoGroups)
})
}
func TestSqlStore_UserInvite_TimestampPrecision(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Skip("store is nil")
}
ctx := context.Background()
now := time.Now().UTC().Truncate(time.Millisecond)
expiresAt := now.Add(72 * time.Hour)
invite := &types.UserInviteRecord{
ID: "invite-timestamp",
AccountID: "account-timestamp",
Email: "timestamp@example.com",
Name: "Timestamp User",
Role: "user",
AutoGroups: []string{},
HashedToken: "hashed-token-timestamp",
ExpiresAt: expiresAt,
CreatedAt: now,
CreatedBy: "admin-user",
}
err := store.SaveUserInvite(ctx, invite)
require.NoError(t, err)
retrieved, err := store.GetUserInviteByID(ctx, LockingStrengthNone, invite.AccountID, invite.ID)
require.NoError(t, err)
// Verify timestamps are preserved (within reasonable precision)
assert.WithinDuration(t, now, retrieved.CreatedAt, time.Second)
assert.WithinDuration(t, expiresAt, retrieved.ExpiresAt, time.Second)
})
}

View File

@@ -92,6 +92,13 @@ type Store interface {
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
SaveUserInvite(ctx context.Context, invite *types.UserInviteRecord) error
GetUserInviteByID(ctx context.Context, lockStrength LockingStrength, accountID, inviteID string) (*types.UserInviteRecord, error)
GetUserInviteByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.UserInviteRecord, error)
GetUserInviteByEmail(ctx context.Context, lockStrength LockingStrength, accountID, email string) (*types.UserInviteRecord, error)
GetAccountUserInvites(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.UserInviteRecord, error)
DeleteUserInvite(ctx context.Context, inviteID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)

View File

@@ -0,0 +1,201 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"strings"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/util/crypt"
)
const (
// InviteTokenPrefix is the prefix for invite tokens
InviteTokenPrefix = "nbi_"
// InviteTokenSecretLength is the length of the random secret part
InviteTokenSecretLength = 30
// InviteTokenChecksumLength is the length of the encoded checksum
InviteTokenChecksumLength = 6
// InviteTokenLength is the total length of the token (4 + 30 + 6 = 40)
InviteTokenLength = 40
// DefaultInviteExpirationSeconds is the default expiration time for invites (72 hours)
DefaultInviteExpirationSeconds = 259200
// MinInviteExpirationSeconds is the minimum expiration time for invites (1 hour)
MinInviteExpirationSeconds = 3600
)
// UserInviteRecord represents an invitation for a user to set up their account (database model)
type UserInviteRecord struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index;not null"`
Email string `gorm:"index;not null"`
Name string `gorm:"not null"`
Role string `gorm:"not null"`
AutoGroups []string `gorm:"serializer:json"`
HashedToken string `gorm:"index;not null"` // SHA-256 hash of the token (base64 encoded)
ExpiresAt time.Time `gorm:"not null"`
CreatedAt time.Time `gorm:"not null"`
CreatedBy string `gorm:"not null"`
}
// TableName returns the table name for GORM
func (UserInviteRecord) TableName() string {
return "user_invites"
}
// GenerateInviteToken creates a new invite token with the format: nbi_<secret><checksum>
// Returns the hashed token (for storage) and the plain token (to give to the user)
func GenerateInviteToken() (hashedToken string, plainToken string, err error) {
secret, err := b.Random(InviteTokenSecretLength)
if err != nil {
return "", "", fmt.Errorf("failed to generate random secret: %w", err)
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
// Left-pad with '0' to ensure exactly 6 characters (fmt.Sprintf %s pads with spaces which breaks base62.Decode)
paddedChecksum := encodedChecksum
if len(paddedChecksum) < InviteTokenChecksumLength {
paddedChecksum = strings.Repeat("0", InviteTokenChecksumLength-len(paddedChecksum)) + paddedChecksum
}
plainToken = InviteTokenPrefix + secret + paddedChecksum
hash := sha256.Sum256([]byte(plainToken))
hashedToken = b64.StdEncoding.EncodeToString(hash[:])
return hashedToken, plainToken, nil
}
// HashInviteToken creates a SHA-256 hash of the token (base64 encoded)
func HashInviteToken(token string) string {
hash := sha256.Sum256([]byte(token))
return b64.StdEncoding.EncodeToString(hash[:])
}
// ValidateInviteToken validates the token format and checksum.
// Returns an error if the token is invalid.
func ValidateInviteToken(token string) error {
if len(token) != InviteTokenLength {
return fmt.Errorf("invalid token length")
}
prefix := token[:len(InviteTokenPrefix)]
if prefix != InviteTokenPrefix {
return fmt.Errorf("invalid token prefix")
}
secret := token[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
encodedChecksum := token[len(InviteTokenPrefix)+InviteTokenSecretLength:]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return fmt.Errorf("checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return fmt.Errorf("checksum does not match")
}
return nil
}
// IsExpired checks if the invite has expired
func (i *UserInviteRecord) IsExpired() bool {
return time.Now().After(i.ExpiresAt)
}
// UserInvite contains the result of creating or regenerating an invite
type UserInvite struct {
UserInfo *UserInfo
InviteToken string
InviteExpiresAt time.Time
InviteCreatedAt time.Time
}
// UserInviteInfo contains public information about an invite (for unauthenticated endpoint)
type UserInviteInfo struct {
Email string `json:"email"`
Name string `json:"name"`
ExpiresAt time.Time `json:"expires_at"`
Valid bool `json:"valid"`
InvitedBy string `json:"invited_by"`
}
// NewInviteID generates a new invite ID using xid
func NewInviteID() string {
return xid.New().String()
}
// EncryptSensitiveData encrypts the invite's sensitive fields (Email and Name) in place.
func (i *UserInviteRecord) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if i.Email != "" {
i.Email, err = enc.Encrypt(i.Email)
if err != nil {
return fmt.Errorf("encrypt email: %w", err)
}
}
if i.Name != "" {
i.Name, err = enc.Encrypt(i.Name)
if err != nil {
return fmt.Errorf("encrypt name: %w", err)
}
}
return nil
}
// DecryptSensitiveData decrypts the invite's sensitive fields (Email and Name) in place.
func (i *UserInviteRecord) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if i.Email != "" {
i.Email, err = enc.Decrypt(i.Email)
if err != nil {
return fmt.Errorf("decrypt email: %w", err)
}
}
if i.Name != "" {
i.Name, err = enc.Decrypt(i.Name)
if err != nil {
return fmt.Errorf("decrypt name: %w", err)
}
}
return nil
}
// Copy creates a deep copy of the UserInviteRecord
func (i *UserInviteRecord) Copy() *UserInviteRecord {
autoGroups := make([]string, len(i.AutoGroups))
copy(autoGroups, i.AutoGroups)
return &UserInviteRecord{
ID: i.ID,
AccountID: i.AccountID,
Email: i.Email,
Name: i.Name,
Role: i.Role,
AutoGroups: autoGroups,
HashedToken: i.HashedToken,
ExpiresAt: i.ExpiresAt,
CreatedAt: i.CreatedAt,
CreatedBy: i.CreatedBy,
}
}

View File

@@ -0,0 +1,355 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"hash/crc32"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUserInviteRecord_TableName(t *testing.T) {
invite := UserInviteRecord{}
assert.Equal(t, "user_invites", invite.TableName())
}
func TestGenerateInviteToken_Success(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.NotEmpty(t, hashedToken)
assert.NotEmpty(t, plainToken)
}
func TestGenerateInviteToken_Length(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.Len(t, plainToken, InviteTokenLength)
}
func TestGenerateInviteToken_Prefix(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.True(t, strings.HasPrefix(plainToken, InviteTokenPrefix))
}
func TestGenerateInviteToken_Hashing(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
expectedHash := sha256.Sum256([]byte(plainToken))
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
assert.Equal(t, expectedHashedToken, hashedToken)
}
func TestGenerateInviteToken_Checksum(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Extract parts
secret := plainToken[len(InviteTokenPrefix) : len(InviteTokenPrefix)+InviteTokenSecretLength]
checksumStr := plainToken[len(InviteTokenPrefix)+InviteTokenSecretLength:]
// Verify checksum
expectedChecksum := crc32.ChecksumIEEE([]byte(secret))
actualChecksum, err := base62.Decode(checksumStr)
require.NoError(t, err)
assert.Equal(t, expectedChecksum, actualChecksum)
}
func TestGenerateInviteToken_Uniqueness(t *testing.T) {
tokens := make(map[string]bool)
for i := 0; i < 100; i++ {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
assert.False(t, tokens[plainToken], "Token should be unique")
tokens[plainToken] = true
}
}
func TestHashInviteToken(t *testing.T) {
token := "nbi_testtoken123456789012345678901234"
hashedToken := HashInviteToken(token)
expectedHash := sha256.Sum256([]byte(token))
expectedHashedToken := b64.StdEncoding.EncodeToString(expectedHash[:])
assert.Equal(t, expectedHashedToken, hashedToken)
}
func TestHashInviteToken_Consistency(t *testing.T) {
token := "nbi_testtoken123456789012345678901234"
hash1 := HashInviteToken(token)
hash2 := HashInviteToken(token)
assert.Equal(t, hash1, hash2)
}
func TestHashInviteToken_DifferentTokens(t *testing.T) {
token1 := "nbi_testtoken123456789012345678901234"
token2 := "nbi_testtoken123456789012345678901235"
hash1 := HashInviteToken(token1)
hash2 := HashInviteToken(token2)
assert.NotEqual(t, hash1, hash2)
}
func TestValidateInviteToken_Success(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
err = ValidateInviteToken(plainToken)
assert.NoError(t, err)
}
func TestValidateInviteToken_InvalidLength(t *testing.T) {
testCases := []struct {
name string
token string
}{
{"empty", ""},
{"too short", "nbi_abc"},
{"too long", "nbi_" + strings.Repeat("a", 50)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateInviteToken(tc.token)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid token length")
})
}
}
func TestValidateInviteToken_InvalidPrefix(t *testing.T) {
// Create a token with wrong prefix but correct length
token := "xyz_" + strings.Repeat("a", 30) + "000000"
err := ValidateInviteToken(token)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid token prefix")
}
func TestValidateInviteToken_InvalidChecksum(t *testing.T) {
// Create a token with correct format but invalid checksum
token := InviteTokenPrefix + strings.Repeat("a", InviteTokenSecretLength) + "ZZZZZZ"
err := ValidateInviteToken(token)
require.Error(t, err)
assert.Contains(t, err.Error(), "checksum")
}
func TestValidateInviteToken_ModifiedToken(t *testing.T) {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// Modify one character in the secret part
modifiedToken := plainToken[:5] + "X" + plainToken[6:]
err = ValidateInviteToken(modifiedToken)
require.Error(t, err)
}
func TestUserInviteRecord_IsExpired(t *testing.T) {
t.Run("not expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(time.Hour),
}
assert.False(t, invite.IsExpired())
})
t.Run("expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(-time.Hour),
}
assert.True(t, invite.IsExpired())
})
t.Run("just expired", func(t *testing.T) {
invite := &UserInviteRecord{
ExpiresAt: time.Now().Add(-time.Second),
}
assert.True(t, invite.IsExpired())
})
}
func TestNewInviteID(t *testing.T) {
id := NewInviteID()
assert.NotEmpty(t, id)
assert.Len(t, id, 20) // xid generates 20 character IDs
}
func TestNewInviteID_Uniqueness(t *testing.T) {
ids := make(map[string]bool)
for i := 0; i < 100; i++ {
id := NewInviteID()
assert.False(t, ids[id], "ID should be unique")
ids[id] = true
}
}
func TestUserInviteRecord_EncryptDecryptSensitiveData(t *testing.T) {
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
t.Run("encrypt and decrypt", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
// Encrypt
err := invite.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
// Verify encrypted values are different from original
assert.NotEqual(t, "test@example.com", invite.Email)
assert.NotEqual(t, "Test User", invite.Name)
// Decrypt
err = invite.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
// Verify decrypted values match original
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
})
t.Run("encrypt empty fields", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "",
Name: "",
Role: "user",
}
err := invite.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", invite.Email)
assert.Equal(t, "", invite.Name)
err = invite.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", invite.Email)
assert.Equal(t, "", invite.Name)
})
t.Run("nil encryptor", func(t *testing.T) {
invite := &UserInviteRecord{
ID: "test-invite",
AccountID: "test-account",
Email: "test@example.com",
Name: "Test User",
Role: "user",
}
err := invite.EncryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
err = invite.DecryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", invite.Email)
assert.Equal(t, "Test User", invite.Name)
})
}
func TestUserInviteRecord_Copy(t *testing.T) {
now := time.Now()
expiresAt := now.Add(72 * time.Hour)
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
Email: "test@example.com",
Name: "Test User",
Role: "user",
AutoGroups: []string{"group1", "group2"},
HashedToken: "hashed-token",
ExpiresAt: expiresAt,
CreatedAt: now,
CreatedBy: "creator-id",
}
copied := original.Copy()
// Verify all fields are copied
assert.Equal(t, original.ID, copied.ID)
assert.Equal(t, original.AccountID, copied.AccountID)
assert.Equal(t, original.Email, copied.Email)
assert.Equal(t, original.Name, copied.Name)
assert.Equal(t, original.Role, copied.Role)
assert.Equal(t, original.AutoGroups, copied.AutoGroups)
assert.Equal(t, original.HashedToken, copied.HashedToken)
assert.Equal(t, original.ExpiresAt, copied.ExpiresAt)
assert.Equal(t, original.CreatedAt, copied.CreatedAt)
assert.Equal(t, original.CreatedBy, copied.CreatedBy)
// Verify deep copy of AutoGroups (modifying copy doesn't affect original)
copied.AutoGroups[0] = "modified"
assert.NotEqual(t, original.AutoGroups[0], copied.AutoGroups[0])
assert.Equal(t, "group1", original.AutoGroups[0])
}
func TestUserInviteRecord_Copy_EmptyAutoGroups(t *testing.T) {
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
AutoGroups: []string{},
}
copied := original.Copy()
assert.NotNil(t, copied.AutoGroups)
assert.Len(t, copied.AutoGroups, 0)
}
func TestUserInviteRecord_Copy_NilAutoGroups(t *testing.T) {
original := &UserInviteRecord{
ID: "invite-id",
AccountID: "account-id",
AutoGroups: nil,
}
copied := original.Copy()
assert.NotNil(t, copied.AutoGroups)
assert.Len(t, copied.AutoGroups, 0)
}
func TestInviteTokenConstants(t *testing.T) {
// Verify constants are consistent
expectedLength := len(InviteTokenPrefix) + InviteTokenSecretLength + InviteTokenChecksumLength
assert.Equal(t, InviteTokenLength, expectedLength)
assert.Equal(t, 4, len(InviteTokenPrefix))
assert.Equal(t, 30, InviteTokenSecretLength)
assert.Equal(t, 6, InviteTokenChecksumLength)
assert.Equal(t, 40, InviteTokenLength)
assert.Equal(t, 259200, DefaultInviteExpirationSeconds) // 72 hours
assert.Equal(t, 3600, MinInviteExpirationSeconds) // 1 hour
}
func TestGenerateInviteToken_ValidatesOwnOutput(t *testing.T) {
// Generate multiple tokens and ensure they all validate
for i := 0; i < 50; i++ {
_, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
err = ValidateInviteToken(plainToken)
assert.NoError(t, err, "Generated token should always be valid")
}
}
func TestHashInviteToken_MatchesGeneratedHash(t *testing.T) {
hashedToken, plainToken, err := GenerateInviteToken()
require.NoError(t, err)
// HashInviteToken should produce the same hash as GenerateInviteToken
rehashedToken := HashInviteToken(plainToken)
assert.Equal(t, hashedToken, rehashedToken)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"strings"
"time"
"unicode"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/auth"
@@ -1453,3 +1454,368 @@ func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, init
return nil
}
// CreateUserInvite creates an invite link for a new user in the embedded IdP.
// The user is NOT created until the invite is accepted.
func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if err := validateUserInvite(invite); err != nil {
return nil, err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
// Check if user already exists in NetBird DB
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
for _, user := range existingUsers {
if strings.EqualFold(user.Email, invite.Email) {
return nil, status.Errorf(status.UserAlreadyExists, "user with this email already exists")
}
}
// Check if invite already exists for this email
existingInvite, err := am.Store.GetUserInviteByEmail(ctx, store.LockingStrengthNone, accountID, invite.Email)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return nil, fmt.Errorf("failed to check existing invites: %w", err)
}
}
if existingInvite != nil {
return nil, status.Errorf(status.AlreadyExists, "invite already exists for this email")
}
// Calculate expiration time
if expiresIn <= 0 {
expiresIn = types.DefaultInviteExpirationSeconds
}
if expiresIn < types.MinInviteExpirationSeconds {
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
}
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
// Generate invite token
inviteID := types.NewInviteID()
hashedToken, plainToken, err := types.GenerateInviteToken()
if err != nil {
return nil, fmt.Errorf("failed to generate invite token: %w", err)
}
// Create the invite record (no user created yet)
userInvite := &types.UserInviteRecord{
ID: inviteID,
AccountID: accountID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
HashedToken: hashedToken,
ExpiresAt: expiresAt,
CreatedAt: time.Now().UTC(),
CreatedBy: initiatorUserID,
}
if err := am.Store.SaveUserInvite(ctx, userInvite); err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkCreated, map[string]any{"email": invite.Email})
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: inviteID,
Email: invite.Email,
Name: invite.Name,
Role: invite.Role,
AutoGroups: invite.AutoGroups,
Status: string(types.UserStatusInvited),
Issued: types.UserIssuedAPI,
},
InviteToken: plainToken,
InviteExpiresAt: expiresAt,
}, nil
}
// GetUserInviteInfo retrieves invite information from a token (public endpoint).
func (am *DefaultAccountManager) GetUserInviteInfo(ctx context.Context, token string) (*types.UserInviteInfo, error) {
if err := types.ValidateInviteToken(token); err != nil {
return nil, status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
}
hashedToken := types.HashInviteToken(token)
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthNone, hashedToken)
if err != nil {
return nil, err
}
// Get the inviter's name
invitedBy := ""
if invite.CreatedBy != "" {
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, invite.CreatedBy)
if err == nil && inviter != nil {
invitedBy = inviter.Name
}
}
return &types.UserInviteInfo{
Email: invite.Email,
Name: invite.Name,
ExpiresAt: invite.ExpiresAt,
Valid: !invite.IsExpired(),
InvitedBy: invitedBy,
}, nil
}
// ListUserInvites returns all invites for an account.
func (am *DefaultAccountManager) ListUserInvites(ctx context.Context, accountID, initiatorUserID string) ([]*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
records, err := am.Store.GetAccountUserInvites(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
invites := make([]*types.UserInvite, 0, len(records))
for _, record := range records {
invites = append(invites, &types.UserInvite{
UserInfo: &types.UserInfo{
ID: record.ID,
Email: record.Email,
Name: record.Name,
Role: record.Role,
AutoGroups: record.AutoGroups,
},
InviteExpiresAt: record.ExpiresAt,
InviteCreatedAt: record.CreatedAt,
})
}
return invites, nil
}
// AcceptUserInvite accepts an invite and creates the user in both IdP and NetBird DB.
func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, password string) error {
if !IsEmbeddedIdp(am.idpManager) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if password == "" {
return status.Errorf(status.InvalidArgument, "password is required")
}
if err := validatePassword(password); err != nil {
return status.Errorf(status.InvalidArgument, "invalid password: %v", err)
}
if err := types.ValidateInviteToken(token); err != nil {
return status.Errorf(status.InvalidArgument, "invalid invite token: %v", err)
}
hashedToken := types.HashInviteToken(token)
invite, err := am.Store.GetUserInviteByHashedToken(ctx, store.LockingStrengthUpdate, hashedToken)
if err != nil {
return err
}
if invite.IsExpired() {
return status.Errorf(status.InvalidArgument, "invite has expired")
}
// Create user in Dex with the provided password
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return status.Errorf(status.Internal, "failed to get embedded IdP manager")
}
idpUser, err := embeddedIdp.CreateUserWithPassword(ctx, invite.Email, password, invite.Name)
if err != nil {
return fmt.Errorf("failed to create user in IdP: %w", err)
}
// Create user in NetBird DB
newUser := &types.User{
Id: idpUser.ID,
AccountID: invite.AccountID,
Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups,
Issued: types.UserIssuedAPI,
CreatedAt: time.Now().UTC(),
Email: invite.Email,
Name: invite.Name,
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.SaveUser(ctx, newUser); err != nil {
return fmt.Errorf("failed to save user: %w", err)
}
if err := transaction.DeleteUserInvite(ctx, invite.ID); err != nil {
return fmt.Errorf("failed to delete invite: %w", err)
}
return nil
})
if err != nil {
// Best-effort rollback: delete the IdP user to avoid orphaned records
if deleteErr := embeddedIdp.DeleteUser(ctx, idpUser.ID); deleteErr != nil {
log.WithContext(ctx).WithError(deleteErr).Errorf("failed to rollback IdP user %s after transaction failure", idpUser.ID)
}
return err
}
am.StoreEvent(ctx, newUser.Id, newUser.Id, invite.AccountID, activity.UserInviteLinkAccepted, map[string]any{"email": invite.Email})
return nil
}
// RegenerateUserInvite creates a new invite token for an existing invite, invalidating the previous one.
func (am *DefaultAccountManager) RegenerateUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string, expiresIn int) (*types.UserInvite, error) {
if !IsEmbeddedIdp(am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
// Get existing invite
existingInvite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
if err != nil {
return nil, err
}
// Calculate expiration time
if expiresIn <= 0 {
expiresIn = types.DefaultInviteExpirationSeconds
}
if expiresIn < types.MinInviteExpirationSeconds {
return nil, status.Errorf(status.InvalidArgument, "invite expiration must be at least 1 hour")
}
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second)
// Generate new invite token
hashedToken, plainToken, err := types.GenerateInviteToken()
if err != nil {
return nil, fmt.Errorf("failed to generate invite token: %w", err)
}
// Update existing invite with new token and expiration
existingInvite.HashedToken = hashedToken
existingInvite.ExpiresAt = expiresAt
existingInvite.CreatedBy = initiatorUserID
err = am.Store.SaveUserInvite(ctx, existingInvite)
if err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, existingInvite.ID, accountID, activity.UserInviteLinkRegenerated, map[string]any{"email": existingInvite.Email})
return &types.UserInvite{
UserInfo: &types.UserInfo{
ID: existingInvite.ID,
Email: existingInvite.Email,
Name: existingInvite.Name,
Role: existingInvite.Role,
AutoGroups: existingInvite.AutoGroups,
Status: string(types.UserStatusInvited),
Issued: types.UserIssuedAPI,
},
InviteToken: plainToken,
InviteExpiresAt: expiresAt,
}, nil
}
// DeleteUserInvite deletes an existing invite by ID.
func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID, initiatorUserID, inviteID string) error {
if !IsEmbeddedIdp(am.idpManager) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
invite, err := am.Store.GetUserInviteByID(ctx, store.LockingStrengthUpdate, accountID, inviteID)
if err != nil {
return err
}
if err := am.Store.DeleteUserInvite(ctx, inviteID); err != nil {
return err
}
am.StoreEvent(ctx, initiatorUserID, inviteID, accountID, activity.UserInviteLinkDeleted, map[string]any{"email": invite.Email})
return nil
}
const minPasswordLength = 8
// validatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}
var hasDigit, hasUpper, hasSpecial bool
for _, c := range password {
switch {
case unicode.IsDigit(c):
hasDigit = true
case unicode.IsUpper(c):
hasUpper = true
case !unicode.IsLetter(c) && !unicode.IsDigit(c):
hasSpecial = true
}
}
var missing []string
if !hasDigit {
missing = append(missing, "one digit")
}
if !hasUpper {
missing = append(missing, "one uppercase letter")
}
if !hasSpecial {
missing = append(missing, "one special character")
}
if len(missing) > 0 {
return errors.New("password must contain at least " + strings.Join(missing, ", "))
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -488,6 +488,171 @@ components:
- role
- auto_groups
- is_service_user
UserInviteCreateRequest:
type: object
description: Request to create a user invite link
properties:
email:
description: User's email address
type: string
example: user@example.com
name:
description: User's full name
type: string
example: John Doe
role:
description: User's NetBird account role
type: string
example: user
auto_groups:
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
expires_in:
description: Invite expiration time in seconds (default 72 hours)
type: integer
example: 259200
required:
- email
- name
- role
- auto_groups
UserInvite:
type: object
description: A user invite
properties:
id:
description: Invite ID
type: string
example: d5p7eedra0h0lt6f59hg
email:
description: User's email address
type: string
example: user@example.com
name:
description: User's full name
type: string
example: John Doe
role:
description: User's NetBird account role
type: string
example: user
auto_groups:
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
expires_at:
description: Invite expiration time
type: string
format: date-time
example: "2024-01-25T10:00:00Z"
created_at:
description: Invite creation time
type: string
format: date-time
example: "2024-01-22T10:00:00Z"
expired:
description: Whether the invite has expired
type: boolean
example: false
invite_token:
description: The invite link to be shared with the user. Only returned when the invite is created or regenerated.
type: string
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
required:
- id
- email
- name
- role
- auto_groups
- expires_at
- created_at
- expired
UserInviteInfo:
type: object
description: Public information about an invite
properties:
email:
description: User's email address
type: string
example: user@example.com
name:
description: User's full name
type: string
example: John Doe
expires_at:
description: Invite expiration time
type: string
format: date-time
example: "2024-01-25T10:00:00Z"
valid:
description: Whether the invite is still valid (not expired)
type: boolean
example: true
invited_by:
description: Name of the user who sent the invite
type: string
example: Admin User
required:
- email
- name
- expires_at
- valid
- invited_by
UserInviteAcceptRequest:
type: object
description: Request to accept an invite and set password
properties:
password:
description: >-
The password the user wants to set. Must be at least 8 characters long
and contain at least one uppercase letter, one digit, and one special
character (any character that is not a letter or digit, including spaces).
type: string
format: password
minLength: 8
pattern: '^(?=.*[0-9])(?=.*[A-Z])(?=.*[^a-zA-Z0-9]).{8,}$'
example: SecurePass123!
required:
- password
UserInviteAcceptResponse:
type: object
description: Response after accepting an invite
properties:
success:
description: Whether the invite was accepted successfully
type: boolean
example: true
required:
- success
UserInviteRegenerateRequest:
type: object
description: Request to regenerate an invite link
properties:
expires_in:
description: Invite expiration time in seconds (default 72 hours)
type: integer
example: 259200
UserInviteRegenerateResponse:
type: object
description: Response after regenerating an invite
properties:
invite_token:
description: The new invite token
type: string
example: nbi_Xk5Lz9mP2vQwRtYu1aN3bC4dE5fGh0ABC123
invite_expires_at:
description: New invite expiration time
type: string
format: date-time
example: "2024-01-28T10:00:00Z"
required:
- invite_token
- invite_expires_at
PeerMinimum:
type: object
properties:
@@ -2071,7 +2236,8 @@ components:
"dns.zone.create", "dns.zone.update", "dns.zone.delete",
"dns.zone.record.create", "dns.zone.record.update", "dns.zone.record.delete",
"peer.job.create",
"user.password.change"
"user.password.change",
"user.invite.link.create", "user.invite.link.accept", "user.invite.link.regenerate", "user.invite.link.delete"
]
example: route.add
initiator_id:
@@ -2642,6 +2808,29 @@ components:
required:
- user_id
- email
InstanceVersionInfo:
type: object
description: Version information for NetBird components
properties:
management_current_version:
description: The current running version of the management server
type: string
example: "0.35.0"
dashboard_available_version:
description: The latest available version of the dashboard (from GitHub releases)
type: string
example: "2.10.0"
management_available_version:
description: The latest available version of the management server (from GitHub releases)
type: string
example: "0.35.0"
management_update_available:
description: Indicates if a newer management version is available
type: boolean
example: true
required:
- management_current_version
- management_update_available
responses:
not_found:
description: Resource not found
@@ -2694,6 +2883,27 @@ paths:
$ref: '#/components/schemas/InstanceStatus'
'500':
"$ref": "#/components/responses/internal_error"
/api/instance/version:
get:
summary: Get Version Info
description: Returns version information for NetBird components including the current management server version and latest available versions from GitHub.
tags: [ Instance ]
security:
- BearerAuth: []
- TokenAuth: []
responses:
'200':
description: Version information
content:
application/json:
schema:
$ref: '#/components/schemas/InstanceVersionInfo'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/setup:
post:
summary: Setup Instance
@@ -3312,6 +3522,210 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites:
get:
summary: List user invites
description: Lists all pending invites for the account. Only available when embedded IdP is enabled.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: List of invites
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/UserInvite'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a user invite
description: Creates an invite link for a new user. Only available when embedded IdP is enabled. The user is not created until they accept the invite.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: User invite information
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteCreateRequest'
responses:
'200':
description: Invite created successfully
content:
application/json:
schema:
$ref: '#/components/schemas/UserInvite'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'409':
description: User or invite already exists
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'422':
"$ref": "#/components/responses/validation_failed"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{inviteId}:
delete:
summary: Delete a user invite
description: Deletes a pending invite. Only available when embedded IdP is enabled.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: inviteId
required: true
schema:
type: string
description: The ID of the invite to delete
responses:
'200':
description: Invite deleted successfully
content: { }
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
description: Invite not found
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{inviteId}/regenerate:
post:
summary: Regenerate a user invite
description: Regenerates an invite link for an existing invite. Invalidates the previous token and creates a new one.
tags: [ Users ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: inviteId
required: true
schema:
type: string
description: The ID of the invite to regenerate
requestBody:
description: Regenerate options
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteRegenerateRequest'
responses:
'200':
description: Invite regenerated successfully
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteRegenerateResponse'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
description: Invite not found
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled
content: { }
'422':
"$ref": "#/components/responses/validation_failed"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{token}:
get:
summary: Get invite information
description: Retrieves public information about an invite. This endpoint is unauthenticated and protected by the token itself.
tags: [ Users ]
security: []
parameters:
- in: path
name: token
required: true
schema:
type: string
description: The invite token
responses:
'200':
description: Invite information
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteInfo'
'400':
"$ref": "#/components/responses/bad_request"
'404':
description: Invite not found or invalid token
content: { }
'500':
"$ref": "#/components/responses/internal_error"
/api/users/invites/{token}/accept:
post:
summary: Accept an invite
description: Accepts an invite and creates the user with the provided password. This endpoint is unauthenticated and protected by the token itself.
tags: [ Users ]
security: []
parameters:
- in: path
name: token
required: true
schema:
type: string
description: The invite token
requestBody:
description: Password to set for the new user
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteAcceptRequest'
responses:
'200':
description: Invite accepted successfully
content:
application/json:
schema:
$ref: '#/components/schemas/UserInviteAcceptResponse'
'400':
"$ref": "#/components/responses/bad_request"
'404':
description: Invite not found or invalid token
content: { }
'412':
description: Precondition failed - embedded IdP is not enabled or invite expired
content: { }
'422':
"$ref": "#/components/responses/validation_failed"
'500':
"$ref": "#/components/responses/internal_error"
/api/peers:
get:
summary: List all Peers

View File

@@ -123,6 +123,10 @@ const (
EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add"
EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete"
EventActivityCodeUserInvite EventActivityCode = "user.invite"
EventActivityCodeUserInviteLinkAccept EventActivityCode = "user.invite.link.accept"
EventActivityCodeUserInviteLinkCreate EventActivityCode = "user.invite.link.create"
EventActivityCodeUserInviteLinkDelete EventActivityCode = "user.invite.link.delete"
EventActivityCodeUserInviteLinkRegenerate EventActivityCode = "user.invite.link.regenerate"
EventActivityCodeUserJoin EventActivityCode = "user.join"
EventActivityCodeUserPasswordChange EventActivityCode = "user.password.change"
EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete"
@@ -870,6 +874,21 @@ type InstanceStatus struct {
SetupRequired bool `json:"setup_required"`
}
// InstanceVersionInfo Version information for NetBird components
type InstanceVersionInfo struct {
// DashboardAvailableVersion The latest available version of the dashboard (from GitHub releases)
DashboardAvailableVersion *string `json:"dashboard_available_version,omitempty"`
// ManagementAvailableVersion The latest available version of the management server (from GitHub releases)
ManagementAvailableVersion *string `json:"management_available_version,omitempty"`
// ManagementCurrentVersion The current running version of the management server
ManagementCurrentVersion string `json:"management_current_version"`
// ManagementUpdateAvailable Indicates if a newer management version is available
ManagementUpdateAvailable bool `json:"management_update_available"`
}
// JobRequest defines model for JobRequest.
type JobRequest struct {
Workload WorkloadRequest `json:"workload"`
@@ -2166,6 +2185,99 @@ type UserCreateRequest struct {
Role string `json:"role"`
}
// UserInvite A user invite
type UserInvite struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// CreatedAt Invite creation time
CreatedAt time.Time `json:"created_at"`
// Email User's email address
Email string `json:"email"`
// Expired Whether the invite has expired
Expired bool `json:"expired"`
// ExpiresAt Invite expiration time
ExpiresAt time.Time `json:"expires_at"`
// Id Invite ID
Id string `json:"id"`
// InviteToken The invite link to be shared with the user. Only returned when the invite is created or regenerated.
InviteToken *string `json:"invite_token,omitempty"`
// Name User's full name
Name string `json:"name"`
// Role User's NetBird account role
Role string `json:"role"`
}
// UserInviteAcceptRequest Request to accept an invite and set password
type UserInviteAcceptRequest struct {
// Password The password the user wants to set. Must be at least 8 characters long and contain at least one uppercase letter, one digit, and one special character (any character that is not a letter or digit, including spaces).
Password string `json:"password"`
}
// UserInviteAcceptResponse Response after accepting an invite
type UserInviteAcceptResponse struct {
// Success Whether the invite was accepted successfully
Success bool `json:"success"`
}
// UserInviteCreateRequest Request to create a user invite link
type UserInviteCreateRequest struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's email address
Email string `json:"email"`
// ExpiresIn Invite expiration time in seconds (default 72 hours)
ExpiresIn *int `json:"expires_in,omitempty"`
// Name User's full name
Name string `json:"name"`
// Role User's NetBird account role
Role string `json:"role"`
}
// UserInviteInfo Public information about an invite
type UserInviteInfo struct {
// Email User's email address
Email string `json:"email"`
// ExpiresAt Invite expiration time
ExpiresAt time.Time `json:"expires_at"`
// InvitedBy Name of the user who sent the invite
InvitedBy string `json:"invited_by"`
// Name User's full name
Name string `json:"name"`
// Valid Whether the invite is still valid (not expired)
Valid bool `json:"valid"`
}
// UserInviteRegenerateRequest Request to regenerate an invite link
type UserInviteRegenerateRequest struct {
// ExpiresIn Invite expiration time in seconds (default 72 hours)
ExpiresIn *int `json:"expires_in,omitempty"`
}
// UserInviteRegenerateResponse Response after regenerating an invite
type UserInviteRegenerateResponse struct {
// InviteExpiresAt New invite expiration time
InviteExpiresAt time.Time `json:"invite_expires_at"`
// InviteToken The new invite token
InviteToken string `json:"invite_token"`
}
// UserPermissions defines model for UserPermissions.
type UserPermissions struct {
// IsRestricted Indicates whether this User's Peers view is restricted
@@ -2418,6 +2530,15 @@ type PutApiSetupKeysKeyIdJSONRequestBody = SetupKeyRequest
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
type PostApiUsersJSONRequestBody = UserCreateRequest
// PostApiUsersInvitesJSONRequestBody defines body for PostApiUsersInvites for application/json ContentType.
type PostApiUsersInvitesJSONRequestBody = UserInviteCreateRequest
// PostApiUsersInvitesInviteIdRegenerateJSONRequestBody defines body for PostApiUsersInvitesInviteIdRegenerate for application/json ContentType.
type PostApiUsersInvitesInviteIdRegenerateJSONRequestBody = UserInviteRegenerateRequest
// PostApiUsersInvitesTokenAcceptJSONRequestBody defines body for PostApiUsersInvitesTokenAccept for application/json ContentType.
type PostApiUsersInvitesTokenAcceptJSONRequestBody = UserInviteAcceptRequest
// PutApiUsersUserIdJSONRequestBody defines body for PutApiUsersUserId for application/json ContentType.
type PutApiUsersUserIdJSONRequestBody = UserRequest