[management] Add account onboarding (#4084)

This PR introduces a new onboarding feature to handle such flows in the dashboard by defining an AccountOnboarding model, persisting it in the store, exposing CRUD operations in the manager and HTTP handlers, and updating API schemas and tests accordingly.

Add AccountOnboarding struct and embed it in Account
Extend Store and DefaultAccountManager with onboarding methods and SQL migrations
Update HTTP handlers, API types, OpenAPI spec, and add end-to-end tests
This commit is contained in:
Maycon Santos
2025-07-03 09:01:32 +02:00
committed by GitHub
parent 551cb4e467
commit 2c81cf2c1e
14 changed files with 476 additions and 103 deletions

View File

@@ -1204,6 +1204,71 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID)
}
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err)
return nil, err
}
if onboarding == nil {
onboarding = &types.AccountOnboarding{
AccountID: accountID,
}
}
return onboarding, nil
}
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
}
if oldOnboarding == nil {
oldOnboarding = &types.AccountOnboarding{
AccountID: accountID,
}
}
if newOnboarding == nil {
return oldOnboarding, nil
}
if oldOnboarding.IsEqual(*newOnboarding) {
log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID)
return oldOnboarding, nil
}
newOnboarding.AccountID = accountID
err = am.Store.SaveAccountOnboarding(ctx, newOnboarding)
if err != nil {
return nil, fmt.Errorf("failed to update account onboarding: %w", err)
}
return newOnboarding, nil
}
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
if userAuth.UserId == "" {
return "", "", errors.New(emptyUserID)
@@ -1726,6 +1791,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
},
Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
},
}
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {

View File

@@ -39,6 +39,7 @@ type Manager interface {
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
@@ -89,6 +90,7 @@ type Manager interface {
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)

View File

@@ -3448,3 +3448,74 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}
})
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
require.NoError(t, err)
require.NotNil(t, onboarding)
assert.Equal(t, account.Id, onboarding.AccountID)
assert.Equal(t, true, onboarding.OnboardingFlowPending)
assert.Equal(t, true, onboarding.SignupFormPending)
if onboarding.UpdatedAt.IsZero() {
t.Errorf("Onboarding was not retrieved from the store")
}
})
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
account.Id = "with-zero-onboarding"
account.Onboarding = types.AccountOnboarding{}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
require.NoError(t, err)
require.NotNil(t, onboarding)
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
require.Error(t, err, "should return error when onboarding is not set")
})
}
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
onboarding := &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}
t.Run("update onboarding with no change", func(t *testing.T) {
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
require.NoError(t, err)
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
if updated.UpdatedAt.IsZero() {
t.Errorf("Onboarding was updated in the store")
}
})
onboarding.OnboardingFlowPending = false
onboarding.SignupFormPending = false
t.Run("update onboarding", func(t *testing.T) {
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
require.NoError(t, err)
require.NotNil(t, updated)
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
})
t.Run("update onboarding with no onboarding", func(t *testing.T) {
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
require.NoError(t, err)
})
}

View File

@@ -60,6 +60,8 @@ components:
description: Account creator
type: string
example: google-oauth2|277474792786460067937
onboarding:
$ref: '#/components/schemas/AccountOnboarding'
required:
- id
- settings
@@ -67,6 +69,21 @@ components:
- domain_category
- created_at
- created_by
- onboarding
AccountOnboarding:
type: object
properties:
signup_form_pending:
description: Indicates whether the account signup form is pending
type: boolean
example: true
onboarding_flow_pending:
description: Indicates whether the account onboarding flow is pending
type: boolean
example: false
required:
- signup_form_pending
- onboarding_flow_pending
AccountSettings:
type: object
properties:
@@ -153,6 +170,8 @@ components:
properties:
settings:
$ref: '#/components/schemas/AccountSettings'
onboarding:
$ref: '#/components/schemas/AccountOnboarding'
required:
- settings
User:

View File

@@ -251,6 +251,7 @@ type Account struct {
// Id Account ID
Id string `json:"id"`
Onboarding AccountOnboarding `json:"onboarding"`
Settings AccountSettings `json:"settings"`
}
@@ -266,8 +267,18 @@ type AccountExtraSettings struct {
PeerApprovalEnabled bool `json:"peer_approval_enabled"`
}
// AccountOnboarding defines model for AccountOnboarding.
type AccountOnboarding struct {
// OnboardingFlowPending Indicates whether the account onboarding flow is pending
OnboardingFlowPending bool `json:"onboarding_flow_pending"`
// SignupFormPending Indicates whether the account signup form is pending
SignupFormPending bool `json:"signup_form_pending"`
}
// AccountRequest defines model for AccountRequest.
type AccountRequest struct {
Onboarding *AccountOnboarding `json:"onboarding,omitempty"`
Settings AccountSettings `json:"settings"`
}

View File

@@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, settings, meta)
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -126,6 +132,20 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
}
var onboarding *types.AccountOnboarding
if req.Onboarding != nil {
onboarding = &types.AccountOnboarding{
OnboardingFlowPending: req.Onboarding.OnboardingFlowPending,
SignupFormPending: req.Onboarding.SignupFormPending,
}
}
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -138,7 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, updatedSettings, meta)
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
util.WriteJSONObject(r.Context(), w, &resp)
}
@@ -167,7 +187,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account {
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
@@ -188,6 +208,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
DnsDomain: &settings.DNSDomain,
}
apiOnboarding := api.AccountOnboarding{
OnboardingFlowPending: onboarding.OnboardingFlowPending,
SignupFormPending: onboarding.SignupFormPending,
}
if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
@@ -203,5 +228,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
CreatedBy: meta.CreatedBy,
Domain: meta.Domain,
DomainCategory: meta.DomainCategory,
Onboarding: apiOnboarding,
}
}

View File

@@ -54,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
return account.GetMeta(), nil
},
GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
return &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}, nil
},
UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
return &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}, nil
},
},
settingsManager: settingsMockManager,
}
@@ -117,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
@@ -139,7 +151,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
@@ -161,7 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 554400,
@@ -178,12 +190,34 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedArray: false,
expectedID: accountID,
},
{
name: "PutAccount OK without onboarding",
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: false,
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr("roles"),
JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{"test"},
RegularUsersViewBlocked: true,
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
},
expectedArray: false,
expectedID: accountID,
},
{
name: "Update account failure with high peer_login_expiration more than 180 days",
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusUnprocessableEntity,
expectedArray: false,
},
@@ -192,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusUnprocessableEntity,
expectedArray: false,
},

View File

@@ -117,7 +117,8 @@ type MockAccountManager struct {
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error)
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
}
@@ -814,6 +815,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri
return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented")
}
// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface
func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
if am.GetAccountOnboardingFunc != nil {
return am.GetAccountOnboardingFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented")
}
// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface
func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
if am.UpdateAccountOnboardingFunc != nil {
return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented")
}
// GetUserByID mocks GetUserByID of the AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
if am.GetUserByIDFunc != nil {

View File

@@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey)
}
// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding
func NewAccountOnboardingNotFoundError(accountKey string) error {
return Errorf(NotFound, "account onboarding not found: %s", accountKey)
}
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
func NewPeerNotPartOfAccountError() error {
return Errorf(PermissionDenied, "peer is not part of this account")

View File

@@ -99,7 +99,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -728,6 +728,32 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren
return &accountMeta, nil
}
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) {
var accountOnboarding types.AccountOnboarding
result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountOnboardingNotFoundError(accountID)
}
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
return &accountOnboarding, nil
}
// SaveAccountOnboarding updates the onboarding information for a specific account.
func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error {
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
}
return nil
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {

View File

@@ -354,9 +354,16 @@ func TestSqlite_DeleteAccount(t *testing.T) {
t.Errorf("expecting 1 Accounts to be stored after SaveAccount()")
}
o, err := store.GetAccountOnboarding(context.Background(), account.Id)
require.NoError(t, err)
require.Equal(t, o.AccountID, account.Id)
err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err)
_, err = store.GetAccountOnboarding(context.Background(), account.Id)
require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding")
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()")
}
@@ -414,12 +421,21 @@ func Test_GetAccount(t *testing.T) {
account, err := store.GetAccount(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match")
require.Equal(t, false, account.Onboarding.OnboardingFlowPending)
id = "9439-34653001fc3b-bf1c8084-ba50-4ce7"
account, err = store.GetAccount(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match")
require.Equal(t, true, account.Onboarding.OnboardingFlowPending)
_, err = store.GetAccount(context.Background(), "non-existing-account")
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
})
}
@@ -2096,6 +2112,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
},
Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true},
}
if err := acc.AddAllGroup(false); err != nil {
@@ -3440,6 +3457,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) {
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC())
}
func TestSqlStore_GetAccountOnboarding(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
a, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
t.Logf("Onboarding: %+v", a.Onboarding)
err = store.SaveAccount(context.Background(), a)
require.NoError(t, err)
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
require.NotNil(t, onboarding)
require.Equal(t, accountID, onboarding.AccountID)
require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC())
}
func TestSqlStore_SaveAccountOnboarding(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
t.Run("New onboarding should be saved correctly", func(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
onboarding := &types.AccountOnboarding{
AccountID: accountID,
SignupFormPending: true,
OnboardingFlowPending: true,
}
err = store.SaveAccountOnboarding(context.Background(), onboarding)
require.NoError(t, err)
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
})
t.Run("Existing onboarding should be updated correctly", func(t *testing.T) {
accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7"
onboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending
onboarding.SignupFormPending = !onboarding.SignupFormPending
err = store.SaveAccountOnboarding(context.Background(), onboarding)
require.NoError(t, err)
savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID)
require.NoError(t, err)
require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending)
require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending)
})
}
func TestSqlStore_GetAnyAccountID(t *testing.T) {
t.Run("should return account ID when accounts exist", func(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@@ -52,6 +52,7 @@ type Store interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
@@ -74,6 +75,7 @@ type Store interface {
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)

View File

@@ -1,4 +1,5 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`));
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
@@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`);
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');

View File

@@ -83,10 +83,10 @@ type Account struct {
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
}
// Subclass used in gorm to only load network and not whole account
@@ -104,6 +104,20 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
type AccountOnboarding struct {
AccountID string `gorm:"primaryKey"`
OnboardingFlowPending bool
SignupFormPending bool
CreatedAt time.Time
UpdatedAt time.Time
}
// IsEqual compares two AccountOnboarding objects and returns true if they are equal
func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
return o.OnboardingFlowPending == onboarding.OnboardingFlowPending &&
o.SignupFormPending == onboarding.SignupFormPending
}
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
@@ -866,6 +880,7 @@ func (a *Account) Copy() *Account {
Networks: nets,
NetworkRouters: networkRouters,
NetworkResources: networkResources,
Onboarding: a.Onboarding,
}
}