Compare commits

...

20 Commits

Author SHA1 Message Date
Maycon Santos
3613a70c8c Prevent return extra userData
We need to prevent returning extra userdata
to avoid endless cache reload
2023-09-30 21:03:12 +02:00
bcmmbaga
e04fb2e11a Initialize a user data slice with defined capacity 2023-09-27 12:02:05 +03:00
bcmmbaga
1f8407a580 Merge remote-tracking branch 'origin/idp-user-cache' into idp-user-cache
# Conflicts:
#	management/server/account.go
2023-09-27 12:01:22 +03:00
bcmmbaga
7c5c9a2a42 Initialize a user data slice with defined capacity 2023-09-27 12:00:28 +03:00
bcmmbaga
cebf9d93a9 Initialize a user data slice with defined capacity 2023-09-27 10:45:21 +03:00
bcmmbaga
b155ee28f2 Remove unused function and variables 2023-09-26 15:13:07 +03:00
bcmmbaga
baf2546aa4 Merge branch 'main' into idp-user-cache 2023-09-26 15:03:48 +03:00
bcmmbaga
1c4e6b2413 Refactor Google Workspace IdP manager 2023-09-26 13:40:23 +03:00
bcmmbaga
de1fa8b26b Remove unused types declarations 2023-09-26 12:28:06 +03:00
bcmmbaga
c10ebdb227 Refactor Azure IDP manager 2023-09-26 12:27:23 +03:00
bcmmbaga
5adebea24a Refactor Okta IDP manager 2023-09-26 12:00:36 +03:00
bcmmbaga
7706319aab Refactor Keycloak IDP manager 2023-09-25 19:22:48 +03:00
bcmmbaga
62d585302f Refactor Zitadel IDP manager 2023-09-25 18:29:19 +03:00
bcmmbaga
1c7d4e9f8a Merge branch 'main' into idp-user-cache 2023-09-25 10:58:15 +03:00
bcmmbaga
7aa72f6ab6 cleanup 2023-09-22 17:11:02 +03:00
bcmmbaga
a3f6de0115 Refactor Authentik IdP manager 2023-09-22 15:20:32 +03:00
bcmmbaga
519f18bbad Add compatibility for the IDP lacking AppMetadata update capabilities 2023-09-22 15:19:07 +03:00
bcmmbaga
90272718db Merge branch 'main' into idp-user-cache 2023-09-22 13:38:04 +03:00
bcmmbaga
8293c95b1b wip: load user account into cache for idp with no GetAccount support 2023-09-14 17:12:35 +03:00
bcmmbaga
50ecf6f4da wip: Handle user metadata without transmitting them to Identity Provider 2023-09-14 12:49:35 +03:00
12 changed files with 220 additions and 1671 deletions

View File

@@ -577,8 +577,8 @@ func (a *Account) Copy() *Account {
}
routes := map[string]*route.Route{}
for id, route := range a.Routes {
routes[id] = route.Copy()
for id, r := range a.Routes {
routes[id] = r.Copy()
}
nsGroups := map[string]*nbdns.NameServerGroup{}
@@ -928,6 +928,27 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
return err
}
// If the Identity Provider does not support writing AppMetadata,
// in cases like this, we expect it to return all users in an "unset" field.
// We iterate over the users in the "unset" field, look up their AccountID in our store, and
// update their AppMetadata with the AccountID.
if unsetData, ok := userData["unset"]; ok {
for _, user := range unsetData {
accountID, err := am.Store.GetAccountByUser(user.ID)
if err == nil {
data := userData[accountID.Id]
if data == nil {
data = make([]*idp.UserData, 0, 1)
}
user.AppMetadata.WTAccountID = accountID.Id
userData[accountID.Id] = append(data, user)
}
}
}
delete(userData, "unset")
for accountID, users := range userData {
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil {
@@ -994,7 +1015,36 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
log.Debugf("account %s not found in cache, reloading", accountID)
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
accountIDString := fmt.Sprintf("%v", accountID)
account, err := am.Store.GetAccount(accountIDString)
if err != nil {
return nil, err
}
userData, err := am.idpManager.GetAccount(accountIDString)
if err != nil {
return nil, err
}
dataMap := make(map[string]*idp.UserData, len(userData))
for _, datum := range userData {
dataMap[datum.ID] = datum
}
matchedUserData := make([]*idp.UserData, 0)
for _, user := range account.Users {
if user.IsServiceUser {
continue
}
datum, ok := dataMap[user.Id]
if !ok {
log.Warnf("user %s not found in IDP", user.Id)
continue
}
matchedUserData = append(matchedUserData, datum)
}
return matchedUserData, nil
}
func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {

View File

@@ -210,47 +210,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
ctx, err := am.authenticationContext()
if err != nil {
return err
}
userPk, err := strconv.ParseInt(userID, 10, 32)
if err != nil {
return err
}
var pendingInvite bool
if appMetadata.WTPendingInvite != nil {
pendingInvite = *appMetadata.WTPendingInvite
}
patchedUserReq := api.PatchedUserRequest{
Attributes: map[string]interface{}{
wtAccountID: appMetadata.WTAccountID,
wtPendingInvite: pendingInvite,
},
}
_, resp, err := am.apiClient.CoreApi.CoreUsersPartialUpdate(ctx, int32(userPk)).
PatchedUserRequest(patchedUserReq).
Execute()
if err != nil {
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusOK {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user %s, statusCode %d", userID, resp.StatusCode)
}
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
@@ -283,7 +243,10 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}
return parseAuthentikUser(*user)
userData := parseAuthentikUser(*user)
userData.AppMetadata = appMetadata
return userData, nil
}
// GetAccount returns all the users for a given profile.
@@ -293,8 +256,7 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err
}
accountFilter := fmt.Sprintf("{%q:%q}", wtAccountID, accountID)
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Attributes(accountFilter).Execute()
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute()
if err != nil {
return nil, err
}
@@ -313,10 +275,9 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
users := make([]*UserData, 0)
for _, user := range userList.Results {
userData, err := parseAuthentikUser(user)
if err != nil {
return nil, err
}
userData := parseAuthentikUser(user)
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
@@ -350,65 +311,18 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData)
for _, user := range userList.Results {
userData, err := parseAuthentikUser(user)
if err != nil {
return nil, err
}
userData := parseAuthentikUser(user)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
accountID := "unset"
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
return indexedUsers, nil
}
// CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
ctx, err := am.authenticationContext()
if err != nil {
return nil, err
}
groupID, err := am.getUserGroupByName("netbird")
if err != nil {
return nil, err
}
defaultBoolValue := true
createUserRequest := api.UserRequest{
Email: &email,
Name: name,
IsActive: &defaultBoolValue,
Groups: []string{groupID},
Username: email,
Attributes: map[string]interface{}{
wtAccountID: accountID,
wtPendingInvite: &defaultBoolValue,
},
}
user, resp, err := am.apiClient.CoreApi.CoreUsersCreate(ctx).UserRequest(createUserRequest).Execute()
if err != nil {
return nil, err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
return parseAuthentikUser(*user)
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
@@ -438,11 +352,7 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
users := make([]*UserData, 0)
for _, user := range userList.Results {
userData, err := parseAuthentikUser(user)
if err != nil {
return nil, err
}
users = append(users, userData)
users = append(users, parseAuthentikUser(user))
}
return users, nil
@@ -501,64 +411,10 @@ func (am *AuthentikManager) authenticationContext() (context.Context, error) {
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
}
// getUserGroupByName retrieves the user group for assigning new users.
// If the group is not found, a new group with the specified name will be created.
func (am *AuthentikManager) getUserGroupByName(name string) (string, error) {
ctx, err := am.authenticationContext()
if err != nil {
return "", err
}
groupList, resp, err := am.apiClient.CoreApi.CoreGroupsList(ctx).Name(name).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()
if groupList != nil {
if len(groupList.Results) > 0 {
return groupList.Results[0].Pk, nil
}
}
createGroupRequest := api.GroupRequest{Name: name}
group, resp, err := am.apiClient.CoreApi.CoreGroupsCreate(ctx).GroupRequest(createGroupRequest).Execute()
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
return "", fmt.Errorf("unable to create user group, statusCode: %d", resp.StatusCode)
}
return group.Pk, nil
}
func parseAuthentikUser(user api.User) (*UserData, error) {
var attributes struct {
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
}
helper := JsonParser{}
buf, err := helper.Marshal(user.Attributes)
if err != nil {
return nil, err
}
err = helper.Unmarshal(buf, &attributes)
if err != nil {
return nil, err
}
func parseAuthentikUser(user api.User) *UserData {
return &UserData{
Email: *user.Email,
Name: user.Name,
ID: strconv.FormatInt(int64(user.Pk), 10),
AppMetadata: AppMetadata{
WTAccountID: attributes.AccountID,
WTPendingInvite: &attributes.PendingInvite,
},
}, nil
}
}

View File

@@ -1,7 +1,6 @@
package idp
import (
"encoding/json"
"fmt"
"io"
"net/http"
@@ -11,18 +10,12 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
// azure extension properties template
wtAccountIDTpl = "extension_%s_wt_account_id"
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
profileFields = "id,displayName,mail,userPrincipalName"
extensionFields = "id,name,targetObjects"
)
const profileFields = "id,displayName,mail,userPrincipalName"
// AzureManager azure manager client instance.
type AzureManager struct {
@@ -58,21 +51,6 @@ type AzureCredentials struct {
// azureProfile represents an azure user profile.
type azureProfile map[string]any
// passwordProfile represent authentication method for,
// newly created user profile.
type passwordProfile struct {
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
Password string `json:"password"`
}
// azureExtension represent custom attribute,
// that can be added to user objects in Azure Active Directory (AD).
type azureExtension struct {
Name string `json:"name"`
DataType string `json:"dataType"`
TargetObjects []string `json:"targetObjects"`
}
// NewAzureManager creates a new instance of the AzureManager.
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@@ -115,7 +93,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
appMetrics: appMetrics,
}
manager := &AzureManager{
return &AzureManager{
ObjectID: config.ObjectID,
ClientID: config.ClientID,
GraphAPIEndpoint: config.GraphAPIEndpoint,
@@ -123,14 +101,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
credentials: credentials,
helper: helper,
appMetrics: appMetrics,
}
err := manager.configureAppMetadata()
if err != nil {
return nil, err
}
return manager, nil
}, nil
}
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
@@ -236,44 +207,14 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID)
if err != nil {
return nil, err
}
body, err := am.post("users", payload)
if err != nil {
return nil, err
}
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountCreateUser()
}
var profile azureProfile
err = am.helper.Unmarshal(body, &profile)
if err != nil {
return nil, err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
profile[wtAccountIDField] = accountID
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
profile[wtPendingInviteField] = true
return profile.userData(am.ClientID), nil
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
q.Add("$select", profileFields)
body, err := am.get("users/"+userID, q)
if err != nil {
@@ -290,18 +231,17 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
return nil, err
}
return profile.userData(am.ClientID), nil
userData := profile.userData()
userData.AppMetadata = appMetadata
return userData, nil
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
q.Add("$select", profileFields)
body, err := am.get("users/"+email, q)
if err != nil {
@@ -319,20 +259,15 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
}
users := make([]*UserData, 0)
users = append(users, profile.userData(am.ClientID))
users = append(users, profile.userData())
return users, nil
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
q.Add("$select", profileFields)
body, err := am.get("users", q)
if err != nil {
@@ -351,7 +286,10 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
users := make([]*UserData, 0)
for _, profile := range profiles.Value {
users = append(users, profile.userData(am.ClientID))
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
@@ -360,12 +298,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
q := url.Values{}
q.Add("$select", selectFields)
q.Add("$select", profileFields)
body, err := am.get("users", q)
if err != nil {
@@ -384,67 +318,17 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Value {
userData := profile.userData(am.ClientID)
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
userData := profile.userData()
accountID := "unset"
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return err
}
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
data, err := am.helper.Marshal(map[string]any{
wtAccountIDField: appMetadata.WTAccountID,
wtPendingInviteField: appMetadata.WTPendingInvite,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating idp metadata for user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
@@ -454,7 +338,7 @@ func (am *AzureManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Azure
// DeleteUser from Azure.
func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
@@ -491,81 +375,6 @@ func (am *AzureManager) DeleteUser(userID string) error {
return nil
}
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
q := url.Values{}
q.Add("$select", extensionFields)
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.get(resource, q)
if err != nil {
return nil, err
}
var extensions struct{ Value []azureExtension }
err = am.helper.Unmarshal(body, &extensions)
if err != nil {
return nil, err
}
return extensions.Value, nil
}
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
extension := azureExtension{
Name: name,
DataType: "string",
TargetObjects: []string{"User"},
}
payload, err := am.helper.Marshal(extension)
if err != nil {
return nil, err
}
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
body, err := am.post(resource, string(payload))
if err != nil {
return nil, err
}
var userExtension azureExtension
err = am.helper.Unmarshal(body, &userExtension)
if err != nil {
return nil, err
}
return &userExtension, nil
}
// configureAppMetadata sets up app metadata extensions if they do not exists.
func (am *AzureManager) configureAppMetadata() error {
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
extensions, err := am.getUserExtensions()
if err != nil {
return err
}
// If the wt_account_id extension does not already exist, create it.
if !hasExtension(extensions, wtAccountIDField) {
_, err = am.createUserExtension(wtAccountID)
if err != nil {
return err
}
}
// If the wt_pending_invite extension does not already exist, create it.
if !hasExtension(extensions, wtPendingInviteField) {
_, err = am.createUserExtension(wtPendingInvite)
if err != nil {
return err
}
}
return nil
}
// get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
@@ -602,44 +411,8 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
return io.ReadAll(resp.Body)
}
// post perform Post requests.
func (am *AzureManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/%s", am.GraphAPIEndpoint, resource)
req, err := http.NewRequest(http.MethodPost, reqURL, strings.NewReader(body))
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := am.httpClient.Do(req)
if err != nil {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
// userData construct user data from keycloak profile.
func (ap azureProfile) userData(clientID string) *UserData {
func (ap azureProfile) userData() *UserData {
id, ok := ap["id"].(string)
if !ok {
id = ""
@@ -655,66 +428,9 @@ func (ap azureProfile) userData(clientID string) *UserData {
name = ""
}
accountIDField := extensionName(wtAccountIDTpl, clientID)
accountID, ok := ap[accountIDField].(string)
if !ok {
accountID = ""
}
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
pendingInvite, ok := ap[pendingInviteField].(bool)
if !ok {
pendingInvite = false
}
return &UserData{
Email: email,
Name: name,
ID: id,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
}
}
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
req := &azureProfile{
"accountEnabled": true,
"displayName": name,
"mailNickName": strings.Join(strings.Split(name, " "), ""),
"userPrincipalName": email,
"passwordProfile": passwordProfile{
ForceChangePasswordNextSignIn: true,
Password: GeneratePassword(8, 1, 1, 1),
},
wtAccountIDField: accountID,
wtPendingInviteField: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func extensionName(extensionTpl, clientID string) string {
clientID = strings.ReplaceAll(clientID, "-", "")
return fmt.Sprintf(extensionTpl, clientID)
}
// hasExtension checks whether a given extension by name,
// exists in an list of extensions.
func hasExtension(extensions []azureExtension, name string) bool {
for _, ext := range extensions {
if ext.Name == name {
return true
}
}
return false
}

View File

@@ -8,15 +8,6 @@ import (
"github.com/stretchr/testify/assert"
)
type mockAzureCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockAzureCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestAzureJwtStillValid(t *testing.T) {
type jwtStillValidTest struct {
name string
@@ -124,206 +115,63 @@ func TestAzureAuthenticate(t *testing.T) {
}
}
func TestAzureUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockAzureCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &AzureManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestAzureProfile(t *testing.T) {
type azureProfileTest struct {
name string
clientID string
invite bool
inputProfile azureProfile
expectedUserData UserData
}
azureProfileTestCase1 := azureProfileTest{
name: "Good Request",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
name: "Good Request",
invite: false,
inputProfile: azureProfile{
"id": "test1",
"displayName": "John Doe",
"userPrincipalName": "test1@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
},
expectedUserData: UserData{
Email: "test1@test.com",
Name: "John Doe",
ID: "test1",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase2 := azureProfileTest{
name: "Missing User ID",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: true,
name: "Missing User ID",
invite: true,
inputProfile: azureProfile{
"displayName": "John Doe",
"userPrincipalName": "test2@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
},
expectedUserData: UserData{
Email: "test2@test.com",
Name: "John Doe",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase3 := azureProfileTest{
name: "Missing User Name",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
name: "Missing User Name",
invite: false,
inputProfile: azureProfile{
"id": "test3",
"userPrincipalName": "test3@test.com",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
},
expectedUserData: UserData{
ID: "test3",
Email: "test3@test.com",
AppMetadata: AppMetadata{
WTAccountID: "1",
},
},
}
azureProfileTestCase4 := azureProfileTest{
name: "Missing Extension Fields",
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
invite: false,
inputProfile: azureProfile{
"id": "test4",
"displayName": "John Doe",
"userPrincipalName": "test4@test.com",
},
expectedUserData: UserData{
ID: "test4",
Name: "John Doe",
Email: "test4@test.com",
AppMetadata: AppMetadata{},
},
}
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
t.Run(testCase.name, func(t *testing.T) {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
userData := testCase.inputProfile.userData(testCase.clientID)
userData := testCase.inputProfile.userData()
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
})
}
}

View File

@@ -5,15 +5,14 @@ import (
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// GoogleWorkspaceManager Google Workspace manager client instance.
@@ -73,17 +72,13 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
}
service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserScope, admin.AdminDirectoryUserschemaScope),
option.WithScopes(admin.AdminDirectoryUserScope),
option.WithCredentials(adminCredentials),
)
if err != nil {
return nil, err
}
if err = configureAppMetadataSchema(service, config.CustomerID); err != nil {
return nil, err
}
return &GoogleWorkspaceManager{
usersService: service.Users,
CustomerID: config.CustomerID,
@@ -95,27 +90,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
metadata, err := gm.helper.Marshal(appMetadata)
if err != nil {
return err
}
user := &admin.User{
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": metadata,
},
}
_, err = gm.usersService.Update(userID, user).Do()
if err != nil {
return err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
@@ -130,23 +105,23 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
}
return parseGoogleWorkspaceUser(user)
userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata = appMetadata
return userData, nil
}
// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
query := fmt.Sprintf("app_metadata.wt_account_id=\"%s\"", accountID)
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Query(query).Projection("full").Do()
usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do()
if err != nil {
return nil, err
}
usersData := make([]*UserData, 0)
for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
userData := parseGoogleWorkspaceUser(user)
userData.AppMetadata.WTAccountID = accountID
usersData = append(usersData, userData)
}
@@ -168,61 +143,16 @@ func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, erro
indexedUsers := make(map[string][]*UserData)
for _, user := range usersList.Users {
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
accountID := "unset"
indexedUsers[accountID] = append(indexedUsers[accountID], parseGoogleWorkspaceUser(user))
}
return indexedUsers, nil
}
// CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
invite := true
metadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
username := &admin.UserName{}
fields := strings.Fields(name)
if n := len(fields); n > 0 {
username.GivenName = strings.Join(fields[:n-1], " ")
username.FamilyName = fields[n-1]
}
payload, err := gm.helper.Marshal(metadata)
if err != nil {
return nil, err
}
user := &admin.User{
Name: username,
PrimaryEmail: email,
CustomSchemas: map[string]googleapi.RawMessage{
"app_metadata": payload,
},
Password: GeneratePassword(8, 1, 1, 1),
}
user, err = gm.usersService.Insert(user).Do()
if err != nil {
return nil, err
}
if gm.appMetrics != nil {
gm.appMetrics.IDPMetrics().CountCreateUser()
}
return parseGoogleWorkspaceUser(user)
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
@@ -237,13 +167,8 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
}
userData, err := parseGoogleWorkspaceUser(user)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
users = append(users, userData)
users = append(users, parseGoogleWorkspaceUser(user))
return users, nil
}
@@ -281,7 +206,6 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
creds, err := google.CredentialsFromJSON(
context.Background(),
decodeKey,
admin.AdminDirectoryUserschemaScope,
admin.AdminDirectoryUserScope,
)
if err == nil {
@@ -294,7 +218,6 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
creds, err = google.FindDefaultCredentials(
context.Background(),
admin.AdminDirectoryUserschemaScope,
admin.AdminDirectoryUserScope,
)
if err != nil {
@@ -304,62 +227,11 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
return creds, nil
}
// configureAppMetadataSchema create a custom schema for managing app metadata fields in Google Workspace.
func configureAppMetadataSchema(service *admin.Service, customerID string) error {
schemaList, err := service.Schemas.List(customerID).Do()
if err != nil {
return err
}
// checks if app_metadata schema is already created
for _, schema := range schemaList.Schemas {
if schema.SchemaName == "app_metadata" {
return nil
}
}
// create new app_metadata schema
appMetadataSchema := &admin.Schema{
SchemaName: "app_metadata",
Fields: []*admin.SchemaFieldSpec{
{
FieldName: "wt_account_id",
FieldType: "STRING",
MultiValued: false,
},
{
FieldName: "wt_pending_invite",
FieldType: "BOOL",
MultiValued: false,
},
},
}
_, err = service.Schemas.Insert(customerID, appMetadataSchema).Do()
if err != nil {
return err
}
return nil
}
// parseGoogleWorkspaceUser parse google user to UserData.
func parseGoogleWorkspaceUser(user *admin.User) (*UserData, error) {
var appMetadata AppMetadata
// Get app metadata from custom schemas
if user.CustomSchemas != nil {
rawMessage := user.CustomSchemas["app_metadata"]
helper := JsonParser{}
if err := helper.Unmarshal(rawMessage, &appMetadata); err != nil {
return nil, err
}
}
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
return &UserData{
ID: user.Id,
Email: user.PrimaryEmail,
Name: user.Name.FullName,
AppMetadata: appMetadata,
}, nil
ID: user.Id,
Email: user.PrimaryEmail,
Name: user.Name.FullName,
}
}

View File

@@ -9,6 +9,14 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
type Error struct {
message string
}
func (e *Error) Error() string {
return e.message
}
// Manager idp manager interface
type Manager interface {
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error

View File

@@ -1,12 +1,10 @@
package idp
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"sync"
@@ -18,11 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
wtAccountID = "wt_account_id"
wtPendingInvite = "wt_pending_invite"
)
// KeycloakManager keycloak manager client instance.
type KeycloakManager struct {
adminEndpoint string
@@ -51,28 +44,10 @@ type KeycloakCredentials struct {
appMetrics telemetry.AppMetrics
}
// keycloakUserCredential describe the authentication method for,
// newly created user profile.
type keycloakUserCredential struct {
Type string `json:"type"`
Value string `json:"value"`
Temporary bool `json:"temporary"`
}
// keycloakUserAttributes holds additional user data fields.
type keycloakUserAttributes map[string][]string
// createUserRequest is a user create request.
type keycloakCreateUserRequest struct {
Email string `json:"email"`
Username string `json:"username"`
Enabled bool `json:"enabled"`
EmailVerified bool `json:"emailVerified"`
Credentials []keycloakUserCredential `json:"credentials"`
Attributes keycloakUserAttributes `json:"attributes"`
}
// keycloakProfile represents an keycloak user profile response.
// keycloakProfile represents a keycloak user profile response.
type keycloakProfile struct {
ID string `json:"id"`
CreatedTimestamp int64 `json:"createdTimestamp"`
@@ -230,62 +205,8 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata)
if err != nil {
return nil, err
}
reqURL := fmt.Sprintf("%s/users", km.adminEndpoint)
payload := strings.NewReader(payloadString)
req, err := http.NewRequest(http.MethodPost, reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountCreateUser()
}
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
locationHeader := resp.Header.Get("location")
userID, err := extractUserIDFromLocationHeader(locationHeader)
if err != nil {
return nil, err
}
return km.GetUserDataByID(userID, appMetadata)
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
@@ -319,7 +240,7 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
}
// GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get("users/"+userID, nil)
if err != nil {
return nil, err
@@ -338,12 +259,9 @@ func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadat
return profile.userData(), nil
}
// GetAccount returns all the users for a given profile.
// GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
q := url.Values{}
q.Add("q", wtAccountID+":"+accountID)
body, err := km.get("users", q)
profiles, err := km.fetchAllUserProfiles()
if err != nil {
return nil, err
}
@@ -352,15 +270,12 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
km.appMetrics.IDPMetrics().CountGetAccount()
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles {
users = append(users, profile.userData())
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
@@ -369,15 +284,7 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
totalUsers, err := km.totalUsersCount()
if err != nil {
return nil, err
}
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
profiles, err := km.fetchAllUserProfiles()
if err != nil {
return nil, err
}
@@ -386,78 +293,17 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
km.appMetrics.IDPMetrics().CountGetAllAccounts()
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles {
userData := profile.userData()
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
accountID := "unset"
indexedUsers[accountID] = append(indexedUsers[accountID], profile.userData())
}
return indexedUsers, nil
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
return err
}
attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID)
if appMetadata.WTPendingInvite != nil {
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
} else {
attrs.Set(wtPendingInvite, "false")
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID)
data, err := km.helper.Marshal(map[string]any{
"attributes": attrs,
})
if err != nil {
return err
}
payload := strings.NewReader(string(data))
req, err := http.NewRequest(http.MethodPut, reqURL, payload)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating IdP metadata for user %s", userID)
resp, err := km.httpClient.Do(req)
if err != nil {
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if km.appMetrics != nil {
km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusNoContent {
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
}
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
@@ -467,7 +313,7 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Keycloack
// DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate()
if err != nil {
@@ -475,7 +321,6 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
}
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
@@ -508,32 +353,27 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
return nil
}
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
attrs := keycloakUserAttributes{}
attrs.Set(wtAccountID, appMetadata.WTAccountID)
attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite))
req := &keycloakCreateUserRequest{
Email: email,
Username: name,
Enabled: true,
EmailVerified: true,
Credentials: []keycloakUserCredential{
{
Type: "password",
Value: GeneratePassword(8, 1, 1, 1),
Temporary: false,
},
},
Attributes: attrs,
}
str, err := json.Marshal(req)
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount()
if err != nil {
return "", err
return nil, err
}
return string(str), nil
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
if err != nil {
return nil, err
}
profiles := make([]keycloakProfile, 0)
err = km.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
return profiles, nil
}
// get perform Get requests.
@@ -588,53 +428,11 @@ func (km *KeycloakManager) totalUsersCount() (*int, error) {
return &count, nil
}
// extractUserIDFromLocationHeader extracts the user ID from the location,
// header once the user is created successfully
func extractUserIDFromLocationHeader(locationHeader string) (string, error) {
userURL, err := url.Parse(locationHeader)
if err != nil {
return "", err
}
return path.Base(userURL.Path), nil
}
// userData construct user data from keycloak profile.
func (kp keycloakProfile) userData() *UserData {
accountID := kp.Attributes.Get(wtAccountID)
pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite))
if err != nil {
pendingInvite = false
}
return &UserData{
Email: kp.Email,
Name: kp.Username,
ID: kp.ID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pendingInvite,
},
}
}
// Set sets the key to value. It replaces any existing
// values.
func (ka keycloakUserAttributes) Set(key, value string) {
ka[key] = []string{value}
}
// Get returns the first value associated with the given key.
// If there are no values associated with the key, Get returns
// the empty string.
func (ka keycloakUserAttributes) Get(key string) string {
if ka == nil {
return ""
}
values := ka[key]
if len(values) == 0 {
return ""
}
return values[0]
}

View File

@@ -84,15 +84,6 @@ func TestNewKeycloakManager(t *testing.T) {
}
}
type mockKeycloakCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestKeycloakRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
@@ -316,108 +307,3 @@ func TestKeycloakAuthenticate(t *testing.T) {
})
}
}
func TestKeycloakUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID),
appMetadata: appMetadata,
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID),
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 204,
helper: JsonParser{},
managerCreds: &mockKeycloakCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &KeycloakManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// OktaManager okta manager client instance.
@@ -76,11 +76,6 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
return nil, err
}
err = updateUserProfileSchema(client)
if err != nil {
return nil, err
}
credentials := &OktaCredentials{
clientConfig: config,
httpClient: httpClient,
@@ -103,49 +98,8 @@ func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
var (
sendEmail = true
activate = true
userProfile = okta.UserProfile{
"email": email,
"login": email,
wtAccountID: accountID,
wtPendingInvite: true,
}
)
fields := strings.Fields(name)
if n := len(fields); n > 0 {
userProfile["firstName"] = strings.Join(fields[:n-1], " ")
userProfile["lastName"] = fields[n-1]
}
user, resp, err := om.client.User.CreateUser(context.Background(),
okta.CreateUserRequest{
Profile: &userProfile,
},
&query.Params{
Activate: &activate,
SendEmail: &sendEmail,
},
)
if err != nil {
return nil, err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountCreateUser()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
return parseOktaUser(user)
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
@@ -166,7 +120,13 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
}
return parseOktaUser(user)
userData, err := parseOktaUser(user)
if err != nil {
return nil, err
}
userData.AppMetadata = appMetadata
return userData, nil
}
// GetUserByEmail searches users with a given email.
@@ -200,8 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
// GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
search := fmt.Sprintf("profile.wt_account_id eq %q", accountID)
users, resp, err := om.client.User.ListUsers(context.Background(), &query.Params{Search: search})
users, resp, err := om.client.User.ListUsers(context.Background(), nil)
if err != nil {
return nil, err
}
@@ -223,6 +182,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
if err != nil {
return nil, err
}
userData.AppMetadata.WTAccountID = accountID
list = append(list, userData)
}
@@ -256,13 +216,8 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, err
}
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
accountID := "unset"
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
return indexedUsers, nil
@@ -270,46 +225,6 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
}
profile := *user.Profile
if appMetadata.WTPendingInvite != nil {
profile[wtPendingInvite] = *appMetadata.WTPendingInvite
}
if appMetadata.WTAccountID != "" {
profile[wtAccountID] = appMetadata.WTAccountID
}
user.Profile = &profile
_, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil)
if err != nil {
fmt.Println(err.Error())
return err
}
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
if resp.StatusCode != http.StatusOK {
if om.appMetrics != nil {
om.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
}
return nil
}
@@ -341,60 +256,12 @@ func (om *OktaManager) DeleteUser(userID string) error {
return nil
}
// updateUserProfileSchema updates the Okta user schema to include custom fields,
// wt_account_id and wt_pending_invite.
func updateUserProfileSchema(client *okta.Client) error {
// Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird
userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}}
_, resp, err := client.UserSchema.UpdateUserProfile(
context.Background(),
"default",
okta.UserSchema{
Definitions: &okta.UserSchemaDefinitions{
Custom: &okta.UserSchemaPublic{
Id: "#custom",
Type: "object",
Properties: map[string]*okta.UserSchemaAttribute{
wtAccountID: {
MaxLength: 100,
MinLength: 1,
Required: new(bool),
Scope: "NONE",
Title: "Wt Account Id",
Type: "string",
Permissions: userPermissions,
},
wtPendingInvite: {
Required: new(bool),
Scope: "NONE",
Title: "Wt Pending Invite",
Type: "boolean",
Permissions: userPermissions,
},
},
},
},
})
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unable to update user profile schema, statusCode %d", resp.StatusCode)
}
return nil
}
// parseOktaUserToUserData parse okta user to UserData.
func parseOktaUser(user *okta.User) (*UserData, error) {
var oktaUser struct {
Email string `json:"email"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
AccountID string `json:"wt_account_id"`
PendingInvite bool `json:"wt_pending_invite"`
Email string `json:"email"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
}
if user == nil {
@@ -418,9 +285,5 @@ func parseOktaUser(user *okta.User) (*UserData, error) {
Email: oktaUser.Email,
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
ID: user.Id,
AppMetadata: AppMetadata{
WTAccountID: oktaUser.AccountID,
WTPendingInvite: &oktaUser.PendingInvite,
},
}, nil
}

View File

@@ -1,31 +1,28 @@
package idp
import (
"testing"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/stretchr/testify/assert"
"testing"
)
func TestParseOktaUser(t *testing.T) {
type parseOktaUserTest struct {
name string
invite bool
inputProfile *okta.User
expectedUserData *UserData
assertErrFunc assert.ErrorAssertionFunc
}
parseOktaTestCase1 := parseOktaUserTest{
name: "Good Request",
invite: true,
name: "Good Request",
inputProfile: &okta.User{
Id: "123",
Profile: &okta.UserProfile{
"email": "test@example.com",
"firstName": "John",
"lastName": "Doe",
"wt_account_id": "456",
"wt_pending_invite": true,
"email": "test@example.com",
"firstName": "John",
"lastName": "Doe",
},
},
expectedUserData: &UserData{
@@ -41,36 +38,17 @@ func TestParseOktaUser(t *testing.T) {
parseOktaTestCase2 := parseOktaUserTest{
name: "Invalid okta user",
invite: true,
inputProfile: nil,
expectedUserData: nil,
assertErrFunc: assert.Error,
}
parseOktaTestCase3 := parseOktaUserTest{
name: "Invalid pending invite type",
invite: false,
inputProfile: &okta.User{
Id: "123",
Profile: &okta.UserProfile{
"email": "test@example.com",
"firstName": "John",
"lastName": "Doe",
"wt_account_id": "456",
"wt_pending_invite": "true",
},
},
expectedUserData: nil,
assertErrFunc: assert.Error,
}
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2, parseOktaTestCase3} {
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
userData, err := parseOktaUser(testCase.inputProfile)
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
if err == nil {
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
}
})
@@ -83,13 +61,5 @@ func userDataEqual(a, b *UserData) bool {
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
return false
}
if a.AppMetadata.WTAccountID != b.AppMetadata.WTAccountID {
return false
}
if a.AppMetadata.WTPendingInvite != nil && b.AppMetadata.WTPendingInvite != nil &&
*a.AppMetadata.WTPendingInvite != *b.AppMetadata.WTPendingInvite {
return false
}
return true
}

View File

@@ -1,13 +1,10 @@
package idp
import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
@@ -68,12 +65,6 @@ type zitadelUser struct {
type zitadelAttributes map[string][]map[string]any
// zitadelMetadata holds additional user data.
type zitadelMetadata struct {
Key string `json:"key"`
Value string `json:"value"`
}
// zitadelProfile represents an zitadel user profile response.
type zitadelProfile struct {
ID string `json:"id"`
@@ -82,7 +73,6 @@ type zitadelProfile struct {
PreferredLoginName string `json:"preferredLoginName"`
LoginNames []string `json:"loginNames"`
Human *zitadelUser `json:"human"`
Metadata []zitadelMetadata
}
// NewZitadelManager creates a new instance of the ZitadelManager.
@@ -235,42 +225,8 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in zitadel Idp and sends an invite.
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
payload, err := buildZitadelCreateUserRequestPayload(email, name)
if err != nil {
return nil, err
}
body, err := zm.post("users/human/_import", payload)
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountCreateUser()
}
var result struct {
UserID string `json:"userId"`
}
err = zm.helper.Unmarshal(body, &result)
if err != nil {
return nil, err
}
invite := true
appMetadata := AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &invite,
}
// Add metadata to new user
err = zm.UpdateUserAppMetadata(result.UserID, appMetadata)
if err != nil {
return nil, err
}
return zm.GetUserDataByID(result.UserID, appMetadata)
func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
@@ -308,12 +264,6 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
users = append(users, profile.userData())
}
@@ -337,18 +287,15 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
return nil, err
}
metadata, err := zm.getUserMetadata(userID)
if err != nil {
return nil, err
}
profile.User.Metadata = metadata
userData := profile.User.userData()
userData.AppMetadata = appMetadata
return profile.User.userData(), nil
return userData, nil
}
// GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
accounts, err := zm.GetAllAccounts()
body, err := zm.post("users/_search", "")
if err != nil {
return nil, err
}
@@ -357,7 +304,21 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
zm.appMetrics.IDPMetrics().CountGetAccount()
}
return accounts[accountID], nil
var profiles struct{ Result []zitadelProfile }
err = zm.helper.Unmarshal(body, &profiles)
if err != nil {
return nil, err
}
users := make([]*UserData, 0)
for _, profile := range profiles.Result {
userData := profile.userData()
userData.AppMetadata.WTAccountID = accountID
users = append(users, userData)
}
return users, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
@@ -380,22 +341,8 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
indexedUsers := make(map[string][]*UserData)
for _, profile := range profiles.Result {
// fetch user metadata
metadata, err := zm.getUserMetadata(profile.ID)
if err != nil {
return nil, err
}
profile.Metadata = metadata
userData := profile.userData()
accountID := userData.AppMetadata.WTAccountID
if accountID != "" {
if _, ok := indexedUsers[accountID]; !ok {
indexedUsers[accountID] = make([]*UserData, 0)
}
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
}
accountID := "unset"
indexedUsers[accountID] = append(indexedUsers[accountID], profile.userData())
}
return indexedUsers, nil
@@ -403,42 +350,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
if appMetadata.WTPendingInvite == nil {
appMetadata.WTPendingInvite = new(bool)
}
pendingInviteBuf := strconv.AppendBool([]byte{}, *appMetadata.WTPendingInvite)
wtAccountIDValue := base64.StdEncoding.EncodeToString([]byte(appMetadata.WTAccountID))
wtPendingInviteValue := base64.StdEncoding.EncodeToString(pendingInviteBuf)
metadata := zitadelAttributes{
"metadata": {
{
"key": wtAccountID,
"value": wtAccountIDValue,
},
{
"key": wtPendingInvite,
"value": wtPendingInviteValue,
},
},
}
payload, err := zm.helper.Marshal(metadata)
if err != nil {
return err
}
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
_, err = zm.post(resource, string(payload))
if err != nil {
return err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
}
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
@@ -460,24 +372,6 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
}
return nil
}
// getUserMetadata requests user metadata from zitadel via ID.
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
body, err := zm.post(resource, "")
if err != nil {
return nil, err
}
var metadata struct{ Result []zitadelMetadata }
err = zm.helper.Unmarshal(body, &metadata)
if err != nil {
return nil, err
}
return metadata.Result, nil
}
// post perform Post requests.
@@ -517,38 +411,7 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
}
// delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
}
func (zm *ZitadelManager) delete(_ string) error {
return nil
}
@@ -588,38 +451,13 @@ func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
return io.ReadAll(resp.Body)
}
// value returns string represented by the base64 string value.
func (zm zitadelMetadata) value() string {
value, err := base64.StdEncoding.DecodeString(zm.Value)
if err != nil {
return ""
}
return string(value)
}
// userData construct user data from zitadel profile.
func (zp zitadelProfile) userData() *UserData {
var (
email string
name string
wtAccountIDValue string
wtPendingInviteValue bool
email string
name string
)
for _, metadata := range zp.Metadata {
if metadata.Key == wtAccountID {
wtAccountIDValue = metadata.value()
}
if metadata.Key == wtPendingInvite {
value, err := strconv.ParseBool(metadata.value())
if err == nil {
wtPendingInviteValue = value
}
}
}
// Obtain the email for the human account and the login name,
// for the machine account.
if zp.Human != nil {
@@ -636,39 +474,5 @@ func (zp zitadelProfile) userData() *UserData {
Email: email,
Name: name,
ID: zp.ID,
AppMetadata: AppMetadata{
WTAccountID: wtAccountIDValue,
WTPendingInvite: &wtPendingInviteValue,
},
}
}
func buildZitadelCreateUserRequestPayload(email string, name string) (string, error) {
var firstName, lastName string
words := strings.Fields(name)
if n := len(words); n > 0 {
firstName = strings.Join(words[:n-1], " ")
lastName = words[n-1]
}
req := &zitadelUser{
UserName: name,
Profile: zitadelUserInfo{
FirstName: strings.TrimSpace(firstName),
LastName: strings.TrimSpace(lastName),
DisplayName: name,
},
Email: zitadelEmail{
Email: email,
IsEmailVerified: false,
},
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}

View File

@@ -7,9 +7,10 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewZitadelManager(t *testing.T) {
@@ -63,15 +64,6 @@ func TestNewZitadelManager(t *testing.T) {
}
}
type mockZitadelCredentials struct {
jwtToken JWTToken
err error
}
func (mc *mockZitadelCredentials) Authenticate() (JWTToken, error) {
return mc.jwtToken, mc.err
}
func TestZitadelRequestJWTToken(t *testing.T) {
type requestJWTTokenTest struct {
@@ -296,98 +288,6 @@ func TestZitadelAuthenticate(t *testing.T) {
}
}
func TestZitadelUpdateUserAppMetadata(t *testing.T) {
type updateUserAppMetadataTest struct {
name string
inputReqBody string
expectedReqBody string
appMetadata AppMetadata
statusCode int
helper ManagerHelper
managerCreds ManagerCredentials
assertErrFunc assert.ErrorAssertionFunc
assertErrFuncMessage string
}
appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication",
expectedReqBody: "",
appMetadata: appMetadata,
statusCode: 400,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
err: fmt.Errorf("error"),
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Response Parsing",
statusCode: 400,
helper: &mockJsonParser{marshalErrorString: "error"},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.Error,
assertErrFuncMessage: "should return error",
}
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
name: "Good request",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"ZmFsc2U=\"}]}",
appMetadata: appMetadata,
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
invite := true
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Update Pending Invite",
expectedReqBody: "{\"metadata\":[{\"key\":\"wt_account_id\",\"value\":\"b2s=\"},{\"key\":\"wt_pending_invite\",\"value\":\"dHJ1ZQ==\"}]}",
appMetadata: AppMetadata{
WTAccountID: "ok",
WTPendingInvite: &invite,
},
statusCode: 200,
helper: JsonParser{},
managerCreds: &mockZitadelCredentials{
jwtToken: JWTToken{},
},
assertErrFunc: assert.NoError,
assertErrFuncMessage: "shouldn't return error",
}
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4} {
t.Run(testCase.name, func(t *testing.T) {
reqClient := mockHTTPClient{
resBody: testCase.inputReqBody,
code: testCase.statusCode,
}
manager := &ZitadelManager{
httpClient: &reqClient,
credentials: testCase.managerCreds,
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
})
}
}
func TestZitadelProfile(t *testing.T) {
type azureProfileTest struct {
name string
@@ -418,16 +318,6 @@ func TestZitadelProfile(t *testing.T) {
IsEmailVerified: true,
},
},
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "ZmFsc2U=",
},
},
},
expectedUserData: UserData{
ID: "test1",
@@ -451,16 +341,6 @@ func TestZitadelProfile(t *testing.T) {
"machine",
},
Human: nil,
Metadata: []zitadelMetadata{
{
Key: "wt_account_id",
Value: "MQ==",
},
{
Key: "wt_pending_invite",
Value: "dHJ1ZQ==",
},
},
},
expectedUserData: UserData{
ID: "test2",
@@ -480,8 +360,6 @@ func TestZitadelProfile(t *testing.T) {
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
})
}
}