mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 00:54:01 -04:00
Feature/resolve local jwks keys (#5073)
This commit is contained in:
113
idp/dex/logrus_handler.go
Normal file
113
idp/dex/logrus_handler.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package dex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogrusHandler is an slog.Handler that delegates to logrus.
|
||||||
|
// This allows Dex to use the same log format as the rest of NetBird.
|
||||||
|
type LogrusHandler struct {
|
||||||
|
logger *logrus.Logger
|
||||||
|
attrs []slog.Attr
|
||||||
|
groups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter.
|
||||||
|
func NewLogrusHandler(level slog.Level) *LogrusHandler {
|
||||||
|
logger := logrus.New()
|
||||||
|
formatter.SetTextFormatter(logger)
|
||||||
|
|
||||||
|
// Map slog level to logrus level
|
||||||
|
switch level {
|
||||||
|
case slog.LevelDebug:
|
||||||
|
logger.SetLevel(logrus.DebugLevel)
|
||||||
|
case slog.LevelInfo:
|
||||||
|
logger.SetLevel(logrus.InfoLevel)
|
||||||
|
case slog.LevelWarn:
|
||||||
|
logger.SetLevel(logrus.WarnLevel)
|
||||||
|
case slog.LevelError:
|
||||||
|
logger.SetLevel(logrus.ErrorLevel)
|
||||||
|
default:
|
||||||
|
logger.SetLevel(logrus.WarnLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LogrusHandler{logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled reports whether the handler handles records at the given level.
|
||||||
|
func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||||
|
switch level {
|
||||||
|
case slog.LevelDebug:
|
||||||
|
return h.logger.IsLevelEnabled(logrus.DebugLevel)
|
||||||
|
case slog.LevelInfo:
|
||||||
|
return h.logger.IsLevelEnabled(logrus.InfoLevel)
|
||||||
|
case slog.LevelWarn:
|
||||||
|
return h.logger.IsLevelEnabled(logrus.WarnLevel)
|
||||||
|
case slog.LevelError:
|
||||||
|
return h.logger.IsLevelEnabled(logrus.ErrorLevel)
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle handles the Record.
|
||||||
|
func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error {
|
||||||
|
fields := make(logrus.Fields)
|
||||||
|
|
||||||
|
// Add pre-set attributes
|
||||||
|
for _, attr := range h.attrs {
|
||||||
|
fields[attr.Key] = attr.Value.Any()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add record attributes
|
||||||
|
r.Attrs(func(attr slog.Attr) bool {
|
||||||
|
fields[attr.Key] = attr.Value.Any()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
entry := h.logger.WithFields(fields)
|
||||||
|
|
||||||
|
switch r.Level {
|
||||||
|
case slog.LevelDebug:
|
||||||
|
entry.Debug(r.Message)
|
||||||
|
case slog.LevelInfo:
|
||||||
|
entry.Info(r.Message)
|
||||||
|
case slog.LevelWarn:
|
||||||
|
entry.Warn(r.Message)
|
||||||
|
case slog.LevelError:
|
||||||
|
entry.Error(r.Message)
|
||||||
|
default:
|
||||||
|
entry.Info(r.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAttrs returns a new Handler with the given attributes added.
|
||||||
|
func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||||
|
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
|
||||||
|
copy(newAttrs, h.attrs)
|
||||||
|
copy(newAttrs[len(h.attrs):], attrs)
|
||||||
|
return &LogrusHandler{
|
||||||
|
logger: h.logger,
|
||||||
|
attrs: newAttrs,
|
||||||
|
groups: h.groups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithGroup returns a new Handler with the given group appended to the receiver's groups.
|
||||||
|
func (h *LogrusHandler) WithGroup(name string) slog.Handler {
|
||||||
|
newGroups := make([]string, len(h.groups)+1)
|
||||||
|
copy(newGroups, h.groups)
|
||||||
|
newGroups[len(h.groups)] = name
|
||||||
|
return &LogrusHandler{
|
||||||
|
logger: h.logger,
|
||||||
|
attrs: h.attrs,
|
||||||
|
groups: newGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -130,7 +130,21 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
|||||||
|
|
||||||
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
|
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
|
||||||
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
|
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
|
||||||
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
|
||||||
|
logLevel := slog.LevelWarn
|
||||||
|
if yamlConfig.Logger.Level != "" {
|
||||||
|
switch strings.ToLower(yamlConfig.Logger.Level) {
|
||||||
|
case "debug":
|
||||||
|
logLevel = slog.LevelDebug
|
||||||
|
case "info":
|
||||||
|
logLevel = slog.LevelInfo
|
||||||
|
case "warn", "warning":
|
||||||
|
logLevel = slog.LevelWarn
|
||||||
|
case "error":
|
||||||
|
logLevel = slog.LevelError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger := slog.New(NewLogrusHandler(logLevel))
|
||||||
|
|
||||||
stor, err := yamlConfig.Storage.OpenStorage(logger)
|
stor, err := yamlConfig.Storage.OpenStorage(logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -190,6 +190,9 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
|||||||
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
|
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
|
||||||
userDeleteFromIDPEnabled = true
|
userDeleteFromIDPEnabled = true
|
||||||
|
|
||||||
|
// Set LocalAddress for embedded IdP if enabled, used for internal JWT validation
|
||||||
|
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
|
||||||
|
|
||||||
// Ensure HttpConfig exists
|
// Ensure HttpConfig exists
|
||||||
if cfg.HttpConfig == nil {
|
if cfg.HttpConfig == nil {
|
||||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||||
|
|||||||
@@ -68,7 +68,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
|||||||
if len(audiences) > 0 {
|
if len(audiences) > 0 {
|
||||||
audience = audiences[0] // Use the first client ID as the primary audience
|
audience = audiences[0] // Use the first client ID as the primary audience
|
||||||
}
|
}
|
||||||
keysLocation = oauthProvider.GetKeysLocation()
|
// Use localhost keys location for internal validation (management has embedded Dex)
|
||||||
|
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||||
signingKeyRefreshEnabled = true
|
signingKeyRefreshEnabled = true
|
||||||
issuer = oauthProvider.GetIssuer()
|
issuer = oauthProvider.GetIssuer()
|
||||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -27,8 +28,11 @@ const (
|
|||||||
type EmbeddedIdPConfig struct {
|
type EmbeddedIdPConfig struct {
|
||||||
// Enabled indicates whether the embedded IDP is enabled
|
// Enabled indicates whether the embedded IDP is enabled
|
||||||
Enabled bool
|
Enabled bool
|
||||||
// Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2")
|
// Issuer is the OIDC issuer URL (e.g., "https://management.netbird.io/oauth2")
|
||||||
Issuer string
|
Issuer string
|
||||||
|
// LocalAddress is the management server's local listen address (e.g., ":8080" or "localhost:8080")
|
||||||
|
// Used for internal JWT validation to avoid external network calls
|
||||||
|
LocalAddress string
|
||||||
// Storage configuration for the IdP database
|
// Storage configuration for the IdP database
|
||||||
Storage EmbeddedStorageConfig
|
Storage EmbeddedStorageConfig
|
||||||
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
|
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
|
||||||
@@ -146,7 +150,12 @@ var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil)
|
|||||||
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
|
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
|
||||||
type OAuthConfigProvider interface {
|
type OAuthConfigProvider interface {
|
||||||
GetIssuer() string
|
GetIssuer() string
|
||||||
|
// GetKeysLocation returns the public JWKS endpoint URL (uses external issuer URL)
|
||||||
GetKeysLocation() string
|
GetKeysLocation() string
|
||||||
|
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal use.
|
||||||
|
// Management server has embedded Dex and can validate tokens via localhost,
|
||||||
|
// avoiding external network calls and DNS resolution issues during startup.
|
||||||
|
GetLocalKeysLocation() string
|
||||||
GetClientIDs() []string
|
GetClientIDs() []string
|
||||||
GetUserIDClaim() string
|
GetUserIDClaim() string
|
||||||
GetTokenEndpoint() string
|
GetTokenEndpoint() string
|
||||||
@@ -500,6 +509,22 @@ func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
|||||||
return m.provider.GetKeysLocation()
|
return m.provider.GetKeysLocation()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal token validation.
|
||||||
|
// Uses the LocalAddress from config (management server's listen address) since embedded Dex
|
||||||
|
// is served by the management HTTP server, not a standalone Dex server.
|
||||||
|
func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
|
||||||
|
addr := m.config.LocalAddress
|
||||||
|
if addr == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Construct localhost URL from listen address
|
||||||
|
// addr is in format ":port" or "host:port" or "localhost:port"
|
||||||
|
if strings.HasPrefix(addr, ":") {
|
||||||
|
return fmt.Sprintf("http://localhost%s/oauth2/keys", addr)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("http://%s/oauth2/keys", addr)
|
||||||
|
}
|
||||||
|
|
||||||
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
|
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
|
||||||
func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||||
return []string{staticClientDashboard, staticClientCLI}
|
return []string{staticClientDashboard, staticClientCLI}
|
||||||
|
|||||||
@@ -247,3 +247,61 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
|
|||||||
t.Logf(" Raw UUID: %s", rawUserID)
|
t.Logf(" Raw UUID: %s", rawUserID)
|
||||||
t.Logf(" Connector: %s", connectorID)
|
t.Logf(" Connector: %s", connectorID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
localAddress string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "localhost with port",
|
||||||
|
localAddress: "localhost:8080",
|
||||||
|
expected: "http://localhost:8080/oauth2/keys",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localhost with https port",
|
||||||
|
localAddress: "localhost:443",
|
||||||
|
expected: "http://localhost:443/oauth2/keys",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port only format",
|
||||||
|
localAddress: ":8080",
|
||||||
|
expected: "http://localhost:8080/oauth2/keys",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty address",
|
||||||
|
localAddress: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
config := &EmbeddedIdPConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Issuer: "http://localhost:5556/dex",
|
||||||
|
LocalAddress: tt.localAddress,
|
||||||
|
Storage: EmbeddedStorageConfig{
|
||||||
|
Type: "sqlite3",
|
||||||
|
Config: EmbeddedStorageTypeConfig{
|
||||||
|
File: filepath.Join(tmpDir, "dex-"+tt.name+".db"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = manager.Stop(ctx) }()
|
||||||
|
|
||||||
|
result := manager.GetLocalKeysLocation()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user