mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
Validate OIDC issuer when creating or updating (#5074)
This commit is contained in:
@@ -2,7 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/rs/xid"
|
||||
@@ -17,6 +23,69 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// oidcProviderJSON represents the OpenID Connect discovery document
|
||||
type oidcProviderJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
}
|
||||
|
||||
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
|
||||
// and verifying that the returned issuer matches the configured one.
|
||||
func validateOIDCIssuer(ctx context.Context, issuer string) error {
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
|
||||
}
|
||||
|
||||
var p oidcProviderJSON
|
||||
if err := json.Unmarshal(body, &p); err != nil {
|
||||
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||
}
|
||||
|
||||
if p.Issuer != issuer {
|
||||
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateIdentityProviderConfig validates the identity provider configuration including
|
||||
// basic validation and OIDC issuer verification.
|
||||
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
// Validate the issuer by calling the OIDC discovery endpoint
|
||||
if idpConfig.Issuer != "" {
|
||||
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIdentityProviders returns all identity providers for an account
|
||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
@@ -82,8 +151,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
@@ -119,8 +188,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
|
||||
@@ -2,6 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -200,3 +204,109 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "name is required")
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func() *httptest.Server
|
||||
expectedErr error
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "issuer mismatch",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := oidcProviderJSON{Issuer: "https://different-issuer.com"}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerMismatch,
|
||||
expectedErrMsg: "does not match",
|
||||
},
|
||||
{
|
||||
name: "server returns non-200 status",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = w.Write([]byte("not found"))
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||
expectedErrMsg: "404",
|
||||
},
|
||||
{
|
||||
name: "server returns invalid JSON",
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("invalid json"))
|
||||
}))
|
||||
},
|
||||
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||
expectedErrMsg: "failed to decode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err)
|
||||
if tt.expectedErrMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.expectedErrMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_Success(t *testing.T) {
|
||||
// Create a server that returns its own URL as the issuer
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
resp := oidcProviderJSON{Issuer: server.URL}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
|
||||
// Use a URL that will definitely fail to connect
|
||||
err := validateOIDCIssuer(context.Background(), "http://localhost:59999")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable))
|
||||
}
|
||||
|
||||
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
|
||||
// Test that trailing slashes are handled correctly
|
||||
var server *httptest.Server
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
// Return issuer without trailing slash
|
||||
resp := oidcProviderJSON{Issuer: server.URL}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Pass issuer with trailing slash
|
||||
err := validateOIDCIssuer(context.Background(), server.URL+"/")
|
||||
// This should fail because the issuer returned doesn't have trailing slash
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
|
||||
}
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
|
||||
// Identity provider validation errors
|
||||
var (
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderIssuerUnreachable = errors.New("identity provider issuer is unreachable")
|
||||
ErrIdentityProviderIssuerMismatch = errors.New("identity provider issuer does not match the issuer returned by the provider")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
)
|
||||
|
||||
// IdentityProviderType is the type of identity provider
|
||||
|
||||
Reference in New Issue
Block a user