Compare commits

...

4 Commits

Author SHA1 Message Date
Givi Khojanashvili
4246138e41 Fix review note 2023-06-19 18:33:34 +04:00
Givi Khojanashvili
590d977dc7 Fix claim extraciton after testing on real auth0 setup 2023-06-19 18:23:26 +04:00
Givi Khojanashvili
12f28e9aa4 Fix migration fro groups issue 2023-06-19 11:45:08 +04:00
Givi Khojanashvili
e41072e2fc Get groups from the JWT tokens if it feature enabled for account 2023-06-19 10:03:34 +04:00
8 changed files with 213 additions and 6 deletions

View File

@@ -34,6 +34,8 @@ const (
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
@@ -139,6 +141,13 @@ type Settings struct {
// PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
// and add it to account groups.
JWTGroupsEnabled bool
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
}
// Copy copies the Settings struct
@@ -146,6 +155,8 @@ func (s *Settings) Copy() *Settings {
return &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
}
}
@@ -612,6 +623,28 @@ func (a *Account) GetPeer(peerID string) *Peer {
return a.Peers[peerID]
}
// AddJWTGroups to existed groups if they does not exists
func (a *Account) AddJWTGroups(groups []string) (int, error) {
existedGroups := make(map[string]*Group)
for _, g := range a.Groups {
existedGroups[g.Name] = g
}
var count int
for _, name := range groups {
if _, ok := existedGroups[name]; !ok {
id := xid.New().String()
a.Groups[id] = &Group{
ID: id,
Name: name,
Issued: GroupIssuedJWT,
}
count++
}
}
return count, nil
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
@@ -1241,6 +1274,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
}
}
if account.Settings.JWTGroupsEnabled {
if account.Settings.JWTGroupsClaimName == "" {
log.Errorf("JWT groups are enabled but no claim name is set")
return account, user, nil
}
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if slice, ok := claim.([]interface{}); ok {
var groups []string
for _, item := range slice {
if g, ok := item.(string); ok {
groups = append(groups, g)
} else {
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
}
}
n, err := account.AddJWTGroups(groups)
if err != nil {
log.Errorf("failed to add JWT groups: %v", err)
}
if n > 0 {
if err := am.Store.SaveAccount(account); err != nil {
log.Errorf("failed to save account: %v", err)
}
}
} else {
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
}
} else {
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
}
}
return account, user, nil
}
@@ -1344,8 +1409,9 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &Group{
ID: xid.New().String(),
Name: "All",
ID: xid.New().String(),
Name: "All",
Issued: GroupIssuedAPI,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/route"
@@ -460,6 +461,69 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
}
}
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId("", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
_, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain)
require.NoError(t, err, "create init user failed")
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID,
Domain: domain,
UserId: userId,
DomainCategory: "test-category",
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(initAccount)
require.NoError(t, err, "save account failed")
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(initAccount)
require.NoError(t, err, "save account failed")
account, _, err := manager.GetAccountFromToken(claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*Group{}
for _, g := range account.Groups {
groupsByNames[g.Name] = g
}
g1, ok := groupsByNames["group1"]
require.True(t, ok, "group1 should be added to the account")
require.Equal(t, g1.Name, "group1", "group1 name should match")
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
g2, ok := groupsByNames["group2"]
require.True(t, ok, "group2 should be added to the account")
require.Equal(t, g2.Name, "group2", "group2 name should match")
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
})
}
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")

View File

@@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) {
addPeerLabelsToAccount(account, existingLabels)
}
// TODO: delete this block after migration
// Set API as issuer for groups which has not this field
for _, group := range account.Groups {
if group.Issued == "" {
group.Issued = GroupIssuedAPI
}
}
allGroup, err := account.GetGroupAll()
if err != nil {
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)

View File

@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
}
// TODO: outdated, delete this
func TestRestorePolicies_Migration(t *testing.T) {
storeDir := t.TempDir()
@@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) {
"failed to restore a FileStore file - missing Account Policies Sources")
}
func TestRestoreGroups_Migration(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
if err != nil {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
if err != nil {
return
}
// create default group
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
account.Groups = map[string]*Group{
"cfefqs706sqkneg59g3g": {
ID: "cfefqs706sqkneg59g3g",
Name: "All",
},
}
err = store.SaveAccount(account)
require.NoError(t, err, "failed to save account")
// restore account with default group with empty Issue field
if store, err = NewFileStore(storeDir, nil); err != nil {
return
}
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
}
func TestGetAccountByPrivateDomain(t *testing.T) {
storeDir := t.TempDir()

View File

@@ -14,6 +14,9 @@ type Group struct {
// Name visible in the UI
Name string
// Issued of the group
Issued string
// Peers list of the group
Peers []string
}
@@ -45,9 +48,10 @@ func (g *Group) EventMeta() map[string]any {
func (g *Group) Copy() *Group {
return &Group{
ID: g.ID,
Name: g.Name,
Peers: g.Peers[:],
ID: g.ID,
Name: g.Name,
Issued: g.Issued,
Peers: g.Peers[:],
}
}

View File

@@ -1,9 +1,15 @@
package jwtclaims
import (
"github.com/golang-jwt/jwt"
)
// AuthorizationClaims stores authorization information from JWTs
type AuthorizationClaims struct {
UserId string
AccountId string
Domain string
DomainCategory string
Raw jwt.MapClaims
}

View File

@@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
// FromToken extracts claims from the token (after auth)
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
claims := token.Claims.(jwt.MapClaims)
jwtClaims := AuthorizationClaims{}
jwtClaims := AuthorizationClaims{
Raw: claims,
}
userID, ok := claims[c.userIDClaim].(string)
if !ok {
return jwtClaims

View File

@@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
Domain: "test.com",
AccountId: "testAcc",
DomainCategory: "public",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_domain_category": "public",
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
AccountId: "testAcc",
Raw: jwt.MapClaims{
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Domain: "test.com",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
UserId: "test",
Domain: "test.com",
AccountId: "testAcc",
Raw: jwt.MapClaims{
"https://login/wt_account_domain": "test.com",
"https://login/wt_account_id": "testAcc",
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",
@@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
inputAudiance: "https://login/",
inputAuthorizationClaims: AuthorizationClaims{
UserId: "test",
Raw: jwt.MapClaims{
"sub": "test",
},
},
testingFunc: require.EqualValues,
expectedMSG: "extracted claims should match input claims",