Compare commits

...

3 Commits

Author SHA1 Message Date
braginini
8d3e5f508c Fix nil jwt nil
Entire-Checkpoint: cfd28dfcf51a
2026-03-31 17:36:08 +02:00
braginini
8d09ded1db Fix go.mod 2026-03-30 17:28:57 +02:00
braginini
a49a052f05 Fetch signing keys directly from the embedded IdP
Entire-Checkpoint: 5eaefec1fa77
2026-03-30 17:25:07 +02:00
8 changed files with 127 additions and 26 deletions

2
go.mod
View File

@@ -49,6 +49,7 @@ require (
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.3
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang/mock v1.6.0
@@ -181,7 +182,6 @@ require (
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect

View File

@@ -4,6 +4,7 @@ package dex
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -19,10 +20,13 @@ import (
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Config matches what management/internals/server/server.go expects
@@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string {
}
return issuer + "/auth"
}
// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks.
// This avoids HTTP round-trips when the embedded IDP is co-located with the management server.
// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic:
// SigningKeyPub first, then all VerificationKeys, serialized via go-jose.
func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) {
keys, err := p.storage.GetKeys(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get keys from storage: %w", err)
}
if keys.SigningKeyPub == nil {
return nil, fmt.Errorf("no public keys found in storage")
}
// Build the key set exactly as Dex's localSigner.ValidationKeys does:
// signing key first, then all verification (rotated) keys.
joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1)
joseKeys = append(joseKeys, *keys.SigningKeyPub)
for _, vk := range keys.VerificationKeys {
if vk.PublicKey != nil {
joseKeys = append(joseKeys, *vk.PublicKey)
}
}
// Serialize through go-jose (same as Dex's handlePublicKeys handler)
// then deserialize into our Jwks type, so the JSON field mapping is identical
// to what the /keys HTTP endpoint would return.
joseSet := jose.JSONWebKeySet{Keys: joseKeys}
data, err := json.Marshal(joseSet)
if err != nil {
return nil, fmt.Errorf("failed to marshal JWKS: %w", err)
}
jwks := &nbjwt.Jwks{}
if err := json.Unmarshal(data, jwks); err != nil {
return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err)
}
jwks.ExpiresInTime = keys.NextRotation
return jwks, nil
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
@@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
issuer := s.Config.HttpConfig.AuthIssuer
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
var keyFetcher nbjwt.KeyFetcher
// Use embedded IdP configuration if available
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
@@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
if len(audiences) > 0 {
audience = audiences[0] // Use the first client ID as the primary audience
}
// Use localhost keys location for internal validation (management has embedded Dex)
keysLocation = oauthProvider.GetLocalKeysLocation()
keyFetcher = oauthProvider.GetKeyFetcher()
// Fall back to default keys location if direct key fetching is not available
if keyFetcher == nil {
keysLocation = oauthProvider.GetLocalKeysLocation()
}
signingKeyRefreshEnabled = true
issuer = oauthProvider.GetIssuer()
userIDClaim = oauthProvider.GetUserIDClaim()
@@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
keysLocation,
userIDClaim,
audiences,
signingKeyRefreshEnabled)
signingKeyRefreshEnabled,
keyFetcher)
})
}

View File

@@ -33,15 +33,20 @@ type manager struct {
extractor *nbjwt.ClaimsExtractor
}
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator := nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
var jwtValidator *nbjwt.Validator
if keyFetcher != nil {
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
} else {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator = nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
}
claimsExtractor := nbjwt.NewClaimsExtractor(
nbjwt.WithAudience(audience),

View File

@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
if err != nil {
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
err = manager.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
// these tests only assert groups are parsed from token as per account settings
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
t.Run("JWT groups disabled", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
keyId := "test-key"
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
customClaim := func(name string) string {
return fmt.Sprintf("%s/%s", audience, name)

View File

@@ -119,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
@@ -248,7 +248,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/telemetry"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -193,6 +194,9 @@ type OAuthConfigProvider interface {
// Management server has embedded Dex and can validate tokens via localhost,
// avoiding external network calls and DNS resolution issues during startup.
GetLocalKeysLocation() string
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
// or nil if direct key fetching is not supported (falls back to HTTP).
GetKeyFetcher() nbjwt.KeyFetcher
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
@@ -593,6 +597,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
return m.config.CLIRedirectURIs
}
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
return m.provider.GetJWKS
}
// GetKeysLocation returns the JWKS endpoint URL for token validation.
func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()

View File

@@ -25,7 +25,7 @@ import (
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
ExpiresInTime time.Time `json:"-"`
}
// The supported elliptic curves types
@@ -53,12 +53,17 @@ type JSONWebKey struct {
X5c []string `json:"x5c"`
}
// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage)
// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys.
type KeyFetcher func(ctx context.Context) (*Jwks, error)
type Validator struct {
lock sync.Mutex
issuer string
audienceList []string
keysLocation string
idpSignkeyRefreshEnabled bool
keyFetcher KeyFetcher
keys *Jwks
lastForcedRefresh time.Time
}
@@ -85,10 +90,39 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
}
}
// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the
// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP.
func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator {
ctx := context.Background()
keys, err := keyFetcher(ctx)
if err != nil {
log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err)
}
if keys == nil {
keys = &Jwks{}
}
return &Validator{
keys: keys,
issuer: issuer,
audienceList: audienceList,
idpSignkeyRefreshEnabled: true,
keyFetcher: keyFetcher,
}
}
// forcedRefreshCooldown is the minimum time between forced key refreshes
// to prevent abuse from invalid tokens with fake kid values
const forcedRefreshCooldown = 30 * time.Second
// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP.
func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) {
if v.keyFetcher != nil {
return v.keyFetcher(ctx)
}
return getPemKeys(v.keysLocation)
}
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
// If keys are rotated, verify the keys prior to token validation
@@ -131,13 +165,13 @@ func (v *Validator) refreshKeys(ctx context.Context) {
v.lock.Lock()
defer v.lock.Unlock()
refreshedKeys, err := getPemKeys(v.keysLocation)
refreshedKeys, err := v.fetchKeys(ctx)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
return
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
v.keys = refreshedKeys
}
@@ -155,13 +189,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool {
log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh")
refreshedKeys, err := getPemKeys(v.keysLocation)
refreshedKeys, err := v.fetchKeys(ctx)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
return false
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
v.keys = refreshedKeys
v.lastForcedRefresh = time.Now()
return true
@@ -203,7 +237,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
@@ -227,7 +261,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, nil
}