diff --git a/idp/dex/logrus_handler.go b/idp/dex/logrus_handler.go new file mode 100644 index 000000000..d911cb417 --- /dev/null +++ b/idp/dex/logrus_handler.go @@ -0,0 +1,113 @@ +package dex + +import ( + "context" + "log/slog" + + "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/formatter" +) + +// LogrusHandler is an slog.Handler that delegates to logrus. +// This allows Dex to use the same log format as the rest of NetBird. +type LogrusHandler struct { + logger *logrus.Logger + attrs []slog.Attr + groups []string +} + +// NewLogrusHandler creates a new slog handler that wraps logrus with NetBird's text formatter. +func NewLogrusHandler(level slog.Level) *LogrusHandler { + logger := logrus.New() + formatter.SetTextFormatter(logger) + + // Map slog level to logrus level + switch level { + case slog.LevelDebug: + logger.SetLevel(logrus.DebugLevel) + case slog.LevelInfo: + logger.SetLevel(logrus.InfoLevel) + case slog.LevelWarn: + logger.SetLevel(logrus.WarnLevel) + case slog.LevelError: + logger.SetLevel(logrus.ErrorLevel) + default: + logger.SetLevel(logrus.WarnLevel) + } + + return &LogrusHandler{logger: logger} +} + +// Enabled reports whether the handler handles records at the given level. +func (h *LogrusHandler) Enabled(_ context.Context, level slog.Level) bool { + switch level { + case slog.LevelDebug: + return h.logger.IsLevelEnabled(logrus.DebugLevel) + case slog.LevelInfo: + return h.logger.IsLevelEnabled(logrus.InfoLevel) + case slog.LevelWarn: + return h.logger.IsLevelEnabled(logrus.WarnLevel) + case slog.LevelError: + return h.logger.IsLevelEnabled(logrus.ErrorLevel) + default: + return true + } +} + +// Handle handles the Record. +func (h *LogrusHandler) Handle(_ context.Context, r slog.Record) error { + fields := make(logrus.Fields) + + // Add pre-set attributes + for _, attr := range h.attrs { + fields[attr.Key] = attr.Value.Any() + } + + // Add record attributes + r.Attrs(func(attr slog.Attr) bool { + fields[attr.Key] = attr.Value.Any() + return true + }) + + entry := h.logger.WithFields(fields) + + switch r.Level { + case slog.LevelDebug: + entry.Debug(r.Message) + case slog.LevelInfo: + entry.Info(r.Message) + case slog.LevelWarn: + entry.Warn(r.Message) + case slog.LevelError: + entry.Error(r.Message) + default: + entry.Info(r.Message) + } + + return nil +} + +// WithAttrs returns a new Handler with the given attributes added. +func (h *LogrusHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs)) + copy(newAttrs, h.attrs) + copy(newAttrs[len(h.attrs):], attrs) + return &LogrusHandler{ + logger: h.logger, + attrs: newAttrs, + groups: h.groups, + } +} + +// WithGroup returns a new Handler with the given group appended to the receiver's groups. +func (h *LogrusHandler) WithGroup(name string) slog.Handler { + newGroups := make([]string, len(h.groups)+1) + copy(newGroups, h.groups) + newGroups[len(h.groups)] = name + return &LogrusHandler{ + logger: h.logger, + attrs: h.attrs, + groups: newGroups, + } +} diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 09713a226..fae682959 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -130,7 +130,21 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) { // NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) { - logger := slog.New(slog.NewTextHandler(os.Stderr, nil)) + // Configure log level from config, default to WARN to avoid logging sensitive data (emails) + logLevel := slog.LevelWarn + if yamlConfig.Logger.Level != "" { + switch strings.ToLower(yamlConfig.Logger.Level) { + case "debug": + logLevel = slog.LevelDebug + case "info": + logLevel = slog.LevelInfo + case "warn", "warning": + logLevel = slog.LevelWarn + case "error": + logLevel = slog.LevelError + } + } + logger := slog.New(NewLogrusHandler(logLevel)) stor, err := yamlConfig.Storage.OpenStorage(logger) if err != nil { diff --git a/management/cmd/management.go b/management/cmd/management.go index 5391b0866..9dbd4a6d4 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -190,6 +190,9 @@ func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error { // Enable user deletion from IDP by default if EmbeddedIdP is enabled userDeleteFromIDPEnabled = true + // Set LocalAddress for embedded IdP if enabled, used for internal JWT validation + cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort) + // Ensure HttpConfig exists if cfg.HttpConfig == nil { cfg.HttpConfig = &nbconfig.HttpServerConfig{} diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 688ae5241..9f35d436f 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -68,7 +68,8 @@ func (s *BaseServer) AuthManager() auth.Manager { if len(audiences) > 0 { audience = audiences[0] // Use the first client ID as the primary audience } - keysLocation = oauthProvider.GetKeysLocation() + // Use localhost keys location for internal validation (management has embedded Dex) + keysLocation = oauthProvider.GetLocalKeysLocation() signingKeyRefreshEnabled = true issuer = oauthProvider.GetIssuer() userIDClaim = oauthProvider.GetUserIDClaim() diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index 963b5ae3d..7b8e5033c 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "strings" "github.com/dexidp/dex/storage" "github.com/google/uuid" @@ -27,8 +28,11 @@ const ( type EmbeddedIdPConfig struct { // Enabled indicates whether the embedded IDP is enabled Enabled bool - // Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2") + // Issuer is the OIDC issuer URL (e.g., "https://management.netbird.io/oauth2") Issuer string + // LocalAddress is the management server's local listen address (e.g., ":8080" or "localhost:8080") + // Used for internal JWT validation to avoid external network calls + LocalAddress string // Storage configuration for the IdP database Storage EmbeddedStorageConfig // DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client @@ -146,7 +150,12 @@ var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil) // OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows. type OAuthConfigProvider interface { GetIssuer() string + // GetKeysLocation returns the public JWKS endpoint URL (uses external issuer URL) GetKeysLocation() string + // GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal use. + // Management server has embedded Dex and can validate tokens via localhost, + // avoiding external network calls and DNS resolution issues during startup. + GetLocalKeysLocation() string GetClientIDs() []string GetUserIDClaim() string GetTokenEndpoint() string @@ -500,6 +509,22 @@ func (m *EmbeddedIdPManager) GetKeysLocation() string { return m.provider.GetKeysLocation() } +// GetLocalKeysLocation returns the localhost JWKS endpoint URL for internal token validation. +// Uses the LocalAddress from config (management server's listen address) since embedded Dex +// is served by the management HTTP server, not a standalone Dex server. +func (m *EmbeddedIdPManager) GetLocalKeysLocation() string { + addr := m.config.LocalAddress + if addr == "" { + return "" + } + // Construct localhost URL from listen address + // addr is in format ":port" or "host:port" or "localhost:port" + if strings.HasPrefix(addr, ":") { + return fmt.Sprintf("http://localhost%s/oauth2/keys", addr) + } + return fmt.Sprintf("http://%s/oauth2/keys", addr) +} + // GetClientIDs returns the OAuth2 client IDs configured for this provider. func (m *EmbeddedIdPManager) GetClientIDs() []string { return []string{staticClientDashboard, staticClientCLI} diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index cfd9c2b54..04e3f0699 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -247,3 +247,61 @@ func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) { t.Logf(" Raw UUID: %s", rawUserID) t.Logf(" Connector: %s", connectorID) } + +func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) { + ctx := context.Background() + + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + tests := []struct { + name string + localAddress string + expected string + }{ + { + name: "localhost with port", + localAddress: "localhost:8080", + expected: "http://localhost:8080/oauth2/keys", + }, + { + name: "localhost with https port", + localAddress: "localhost:443", + expected: "http://localhost:443/oauth2/keys", + }, + { + name: "port only format", + localAddress: ":8080", + expected: "http://localhost:8080/oauth2/keys", + }, + { + name: "empty address", + localAddress: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAddress: tt.localAddress, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex-"+tt.name+".db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + result := manager.GetLocalKeysLocation() + assert.Equal(t, tt.expected, result) + }) + } +}