mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Compare commits
4 Commits
netmap
...
feat-group
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4246138e41 | ||
|
|
590d977dc7 | ||
|
|
12f28e9aa4 | ||
|
|
e41072e2fc |
@@ -34,6 +34,8 @@ const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
GroupIssuedAPI = "api"
|
||||
GroupIssuedJWT = "jwt"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
@@ -139,6 +141,13 @@ type Settings struct {
|
||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
|
||||
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||
// and add it to account groups.
|
||||
JWTGroupsEnabled bool
|
||||
|
||||
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||
JWTGroupsClaimName string
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -146,6 +155,8 @@ func (s *Settings) Copy() *Settings {
|
||||
return &Settings{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -612,6 +623,28 @@ func (a *Account) GetPeer(peerID string) *Peer {
|
||||
return a.Peers[peerID]
|
||||
}
|
||||
|
||||
// AddJWTGroups to existed groups if they does not exists
|
||||
func (a *Account) AddJWTGroups(groups []string) (int, error) {
|
||||
existedGroups := make(map[string]*Group)
|
||||
for _, g := range a.Groups {
|
||||
existedGroups[g.Name] = g
|
||||
}
|
||||
|
||||
var count int
|
||||
for _, name := range groups {
|
||||
if _, ok := existedGroups[name]; !ok {
|
||||
id := xid.New().String()
|
||||
a.Groups[id] = &Group{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Issued: GroupIssuedJWT,
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||
@@ -1241,6 +1274,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
}
|
||||
}
|
||||
|
||||
if account.Settings.JWTGroupsEnabled {
|
||||
if account.Settings.JWTGroupsClaimName == "" {
|
||||
log.Errorf("JWT groups are enabled but no claim name is set")
|
||||
return account, user, nil
|
||||
}
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if slice, ok := claim.([]interface{}); ok {
|
||||
var groups []string
|
||||
for _, item := range slice {
|
||||
if g, ok := item.(string); ok {
|
||||
groups = append(groups, g)
|
||||
} else {
|
||||
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
||||
}
|
||||
}
|
||||
n, err := account.AddJWTGroups(groups)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add JWT groups: %v", err)
|
||||
}
|
||||
if n > 0 {
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
log.Errorf("failed to save account: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
}
|
||||
|
||||
@@ -1344,8 +1409,9 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: GroupIssuedAPI,
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -460,6 +461,69 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
initAccount := newAccountWithId("", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
_, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
AccountId: accountID,
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
DomainCategory: "test-category",
|
||||
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*Group{}
|
||||
for _, g := range account.Groups {
|
||||
groupsByNames[g.Name] = g
|
||||
}
|
||||
|
||||
g1, ok := groupsByNames["group1"]
|
||||
require.True(t, ok, "group1 should be added to the account")
|
||||
require.Equal(t, g1.Name, "group1", "group1 name should match")
|
||||
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
|
||||
|
||||
g2, ok := groupsByNames["group2"]
|
||||
require.True(t, ok, "group2 should be added to the account")
|
||||
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
|
||||
@@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) {
|
||||
addPeerLabelsToAccount(account, existingLabels)
|
||||
}
|
||||
|
||||
// TODO: delete this block after migration
|
||||
// Set API as issuer for groups which has not this field
|
||||
for _, group := range account.Groups {
|
||||
if group.Issued == "" {
|
||||
group.Issued = GroupIssuedAPI
|
||||
}
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
|
||||
@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
|
||||
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
||||
}
|
||||
|
||||
// TODO: outdated, delete this
|
||||
func TestRestorePolicies_Migration(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
@@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) {
|
||||
"failed to restore a FileStore file - missing Account Policies Sources")
|
||||
}
|
||||
|
||||
func TestRestoreGroups_Migration(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// create default group
|
||||
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
account.Groups = map[string]*Group{
|
||||
"cfefqs706sqkneg59g3g": {
|
||||
ID: "cfefqs706sqkneg59g3g",
|
||||
Name: "All",
|
||||
},
|
||||
}
|
||||
err = store.SaveAccount(account)
|
||||
require.NoError(t, err, "failed to save account")
|
||||
|
||||
// restore account with default group with empty Issue field
|
||||
if store, err = NewFileStore(storeDir, nil); err != nil {
|
||||
return
|
||||
}
|
||||
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
|
||||
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
|
||||
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
|
||||
}
|
||||
|
||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ type Group struct {
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
// Issued of the group
|
||||
Issued string
|
||||
|
||||
// Peers list of the group
|
||||
Peers []string
|
||||
}
|
||||
@@ -45,9 +48,10 @@ func (g *Group) EventMeta() map[string]any {
|
||||
|
||||
func (g *Group) Copy() *Group {
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Peers: g.Peers[:],
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: g.Peers[:],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// AuthorizationClaims stores authorization information from JWTs
|
||||
type AuthorizationClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
DomainCategory string
|
||||
|
||||
Raw jwt.MapClaims
|
||||
}
|
||||
|
||||
@@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
// FromToken extracts claims from the token (after auth)
|
||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{}
|
||||
jwtClaims := AuthorizationClaims{
|
||||
Raw: claims,
|
||||
}
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return jwtClaims
|
||||
|
||||
@@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
DomainCategory: "public",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_domain_category": "public",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Raw: jwt.MapClaims{
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
|
||||
Reference in New Issue
Block a user