Validate OIDC issuer when creating or updating (#5074)

This commit is contained in:
Misha Bragin
2026-01-09 09:45:43 -05:00
committed by GitHub
parent f7967f9ae3
commit 614e7d5b90
3 changed files with 191 additions and 10 deletions

View File

@@ -2,7 +2,13 @@ package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/dexidp/dex/storage"
"github.com/rs/xid"
@@ -17,6 +23,69 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// oidcProviderJSON represents the OpenID Connect discovery document
type oidcProviderJSON struct {
Issuer string `json:"issuer"`
}
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
// and verifying that the returned issuer matches the configured one.
func validateOIDCIssuer(ctx context.Context, issuer string) error {
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
if err != nil {
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
resp, err := httpClient.Do(req)
if err != nil {
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
}
var p oidcProviderJSON
if err := json.Unmarshal(body, &p); err != nil {
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
if p.Issuer != issuer {
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
}
return nil
}
// validateIdentityProviderConfig validates the identity provider configuration including
// basic validation and OIDC issuer verification.
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
if err := idpConfig.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
// Validate the issuer by calling the OIDC discovery endpoint
if idpConfig.Issuer != "" {
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
return nil
}
// GetIdentityProviders returns all identity providers for an account
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
@@ -82,8 +151,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
return nil, status.NewPermissionDeniedError()
}
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
return nil, err
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
@@ -119,8 +188,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
return nil, status.NewPermissionDeniedError()
}
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
return nil, err
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)

View File

@@ -2,6 +2,10 @@ package server
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
@@ -200,3 +204,109 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "name is required")
}
func TestValidateOIDCIssuer(t *testing.T) {
tests := []struct {
name string
setupServer func() *httptest.Server
expectedErr error
expectedErrMsg string
}{
{
name: "issuer mismatch",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := oidcProviderJSON{Issuer: "https://different-issuer.com"}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
},
expectedErr: types.ErrIdentityProviderIssuerMismatch,
expectedErrMsg: "does not match",
},
{
name: "server returns non-200 status",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte("not found"))
}))
},
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
expectedErrMsg: "404",
},
{
name: "server returns invalid JSON",
setupServer: func() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("invalid json"))
}))
},
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
expectedErrMsg: "failed to decode",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := tt.setupServer()
defer server.Close()
err := validateOIDCIssuer(context.Background(), server.URL)
require.Error(t, err)
assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err)
if tt.expectedErrMsg != "" {
assert.Contains(t, err.Error(), tt.expectedErrMsg)
}
})
}
}
func TestValidateOIDCIssuer_Success(t *testing.T) {
// Create a server that returns its own URL as the issuer
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
resp := oidcProviderJSON{Issuer: server.URL}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
err := validateOIDCIssuer(context.Background(), server.URL)
require.NoError(t, err)
}
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
// Use a URL that will definitely fail to connect
err := validateOIDCIssuer(context.Background(), "http://localhost:59999")
require.Error(t, err)
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable))
}
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
// Test that trailing slashes are handled correctly
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/.well-known/openid-configuration" {
http.NotFound(w, r)
return
}
// Return issuer without trailing slash
resp := oidcProviderJSON{Issuer: server.URL}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
// Pass issuer with trailing slash
err := validateOIDCIssuer(context.Background(), server.URL+"/")
// This should fail because the issuer returned doesn't have trailing slash
require.Error(t, err)
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
}

View File

@@ -7,12 +7,14 @@ import (
// Identity provider validation errors
var (
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
ErrIdentityProviderIssuerUnreachable = errors.New("identity provider issuer is unreachable")
ErrIdentityProviderIssuerMismatch = errors.New("identity provider issuer does not match the issuer returned by the provider")
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
)
// IdentityProviderType is the type of identity provider