Files
netbird/management/server/identity_provider_test.go
Zoltan Papp 58daa674ef [Management/Client] Trigger debug bundle runs from API/Dashboard (#4592) (#4832)
This PR adds the ability to trigger debug bundle generation remotely from the Management API/Dashboard.
2026-01-19 11:22:16 +01:00

315 lines
10 KiB
Go

package server
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
)
// createManagerWithEmbeddedIdP builds a DefaultAccountManager for tests that
// need an embedded IdP manager instead of the default test setup. It creates a
// SQL-backed test store and an embedded IdP whose SQLite storage file lives in
// the test's temporary directory, and wires them into BuildManager together
// with a mocked settings manager and MockIntegratedValidator.
//
// It returns the account manager and the peers update manager it was built
// with. Any setup failure is returned as an error; resources created before
// the failure have already been registered with t.Cleanup and are released.
func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
	t.Helper()
	ctx := context.Background()
	dataDir := t.TempDir()
	testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "", dataDir)
	if err != nil {
		return nil, nil, err
	}
	t.Cleanup(cleanUp)
	// Create embedded IdP manager
	// Its SQLite database is placed in the same temporary dataDir as the test
	// store, so it is removed together with the t.TempDir directory.
	embeddedConfig := &idp.EmbeddedIdPConfig{
		Enabled: true,
		Issuer:  "http://localhost:5556/dex",
		Storage: idp.EmbeddedStorageConfig{
			Type: "sqlite3",
			Config: idp.EmbeddedStorageTypeConfig{
				File: filepath.Join(dataDir, "dex.db"),
			},
		},
	}
	idpManager, err := idp.NewEmbeddedIdPManager(ctx, embeddedConfig, nil)
	if err != nil {
		return nil, nil, err
	}
	// t.Cleanup callbacks run in last-registered-first order, so the IdP is
	// stopped before the store's cleanup runs. Errors from Stop are ignored.
	t.Cleanup(func() { _ = idpManager.Stop(ctx) })
	eventStore := &activity.InMemoryEventStore{}
	metrics, err := telemetry.NewDefaultAppMetrics(ctx)
	if err != nil {
		return nil, nil, err
	}
	ctrl := gomock.NewController(t)
	// NOTE(review): newer golang/mock releases already register Finish through
	// t.Cleanup inside NewController, which would make this line redundant —
	// confirm against the module version in go.mod.
	t.Cleanup(ctrl.Finish)
	// The settings mock accepts any number of calls: reads return an empty
	// ExtraSettings and updates report (false, nil).
	settingsMockManager := settings.NewMockManager(ctrl)
	settingsMockManager.EXPECT().
		GetExtraSettings(gomock.Any(), gomock.Any()).
		Return(&types.ExtraSettings{}, nil).
		AnyTimes()
	settingsMockManager.EXPECT().
		UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
		Return(false, nil).
		AnyTimes()
	permissionsManager := permissions.NewManager(testStore)
	peersManager := peers.NewManager(testStore, permissionsManager)
	updateManager := update_channel.NewPeersUpdateManager(metrics)
	requestBuffer := NewAccountRequestBuffer(ctx, testStore)
	networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{})
	// idpManager is passed here in place of the default test setup's IdP
	// argument; the remaining arguments mirror a plain test configuration.
	manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
	if err != nil {
		return nil, nil, err
	}
	return manager, updateManager, nil
}
// TestDefaultAccountManager_CreateIdentityProvider_Validation verifies that
// CreateIdentityProvider rejects identity providers that are missing a required
// field (name, type, issuer or client ID), and that the error names the field.
func TestDefaultAccountManager_CreateIdentityProvider_Validation(t *testing.T) {
	manager, _, err := createManager(t)
	require.NoError(t, err)

	userID := "testingUser"
	account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
	require.NoError(t, err)

	testCases := []struct {
		name        string
		idp         *types.IdentityProvider
		expectError bool
		errorMsg    string
	}{
		{
			name: "Missing Name",
			idp: &types.IdentityProvider{
				Type:     types.IdentityProviderTypeOIDC,
				Issuer:   "https://issuer.example.com",
				ClientID: "client-id",
			},
			expectError: true,
			errorMsg:    "name is required",
		},
		{
			name: "Missing Type",
			idp: &types.IdentityProvider{
				Name:     "Test IDP",
				Issuer:   "https://issuer.example.com",
				ClientID: "client-id",
			},
			expectError: true,
			errorMsg:    "type is required",
		},
		{
			name: "Missing Issuer",
			idp: &types.IdentityProvider{
				Name:     "Test IDP",
				Type:     types.IdentityProviderTypeOIDC,
				ClientID: "client-id",
			},
			expectError: true,
			errorMsg:    "issuer is required",
		},
		{
			name: "Missing ClientID",
			idp: &types.IdentityProvider{
				Name:   "Test IDP",
				Type:   types.IdentityProviderTypeOIDC,
				Issuer: "https://issuer.example.com",
			},
			expectError: true,
			errorMsg:    "client ID is required",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := manager.CreateIdentityProvider(context.Background(), account.Id, userID, tc.idp)
			if !tc.expectError {
				// A case that expects success must actually succeed; without
				// this check it would assert nothing and pass vacuously.
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			assert.Contains(t, err.Error(), tc.errorMsg)
		})
	}
}
// TestDefaultAccountManager_GetIdentityProviders checks that listing identity
// providers for a newly created account succeeds and yields no entries.
func TestDefaultAccountManager_GetIdentityProviders(t *testing.T) {
	const owner = "testingUser"
	ctx := context.Background()

	am, _, setupErr := createManager(t)
	require.NoError(t, setupErr)

	acc, accErr := am.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: owner})
	require.NoError(t, accErr)

	// An empty list is expected here. NOTE(review): this was previously
	// attributed to a stub implementation — confirm that is still accurate.
	list, listErr := am.GetIdentityProviders(ctx, acc.Id, owner)
	require.NoError(t, listErr)
	assert.Empty(t, list)
}
// TestDefaultAccountManager_GetIdentityProvider_NotFound checks that, with the
// embedded IdP configured, looking up an unknown identity provider ID fails
// with an error that mentions "not found".
func TestDefaultAccountManager_GetIdentityProvider_NotFound(t *testing.T) {
	ctx := context.Background()
	const owner = "testingUser"

	am, _, setupErr := createManagerWithEmbeddedIdP(t)
	require.NoError(t, setupErr)

	acc, accErr := am.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: owner})
	require.NoError(t, accErr)

	// "any-id" was never created, so the lookup must fail.
	_, lookupErr := am.GetIdentityProvider(ctx, acc.Id, "any-id", owner)
	require.Error(t, lookupErr)
	assert.Contains(t, lookupErr.Error(), "not found")
}
// TestDefaultAccountManager_UpdateIdentityProvider_Validation checks that
// UpdateIdentityProvider rejects an identity provider with an empty name.
func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
	ctx := context.Background()
	const owner = "testingUser"

	am, _, setupErr := createManager(t)
	require.NoError(t, setupErr)

	acc, accErr := am.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: owner})
	require.NoError(t, accErr)

	// "some-id" does not exist. The test expects input validation to reject
	// the empty name before any other failure is reported.
	_, updateErr := am.UpdateIdentityProvider(ctx, acc.Id, "some-id", owner, &types.IdentityProvider{Name: ""})
	require.Error(t, updateErr)
	assert.Contains(t, updateErr.Error(), "name is required")
}
// TestValidateOIDCIssuer covers the failure modes of validateOIDCIssuer:
//   - a discovery document that advertises a different issuer;
//   - a non-200 response;
//   - a body that is not valid JSON.
//
// Each case checks the sentinel error and, when set, a fragment of the message.
func TestValidateOIDCIssuer(t *testing.T) {
	cases := []struct {
		name    string
		handler http.HandlerFunc
		wantErr error
		wantMsg string
	}{
		{
			name: "issuer mismatch",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(oidcProviderJSON{Issuer: "https://different-issuer.com"})
			},
			wantErr: types.ErrIdentityProviderIssuerMismatch,
			wantMsg: "does not match",
		},
		{
			name: "server returns non-200 status",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.WriteHeader(http.StatusNotFound)
				_, _ = w.Write([]byte("not found"))
			},
			wantErr: types.ErrIdentityProviderIssuerUnreachable,
			wantMsg: "404",
		},
		{
			name: "server returns invalid JSON",
			handler: func(w http.ResponseWriter, _ *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte("invalid json"))
			},
			wantErr: types.ErrIdentityProviderIssuerUnreachable,
			wantMsg: "failed to decode",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			srv := httptest.NewServer(tc.handler)
			defer srv.Close()

			err := validateOIDCIssuer(context.Background(), srv.URL)
			require.Error(t, err)
			assert.True(t, errors.Is(err, tc.wantErr), "expected error %v, got %v", tc.wantErr, err)
			if tc.wantMsg == "" {
				return
			}
			assert.Contains(t, err.Error(), tc.wantMsg)
		})
	}
}
// TestValidateOIDCIssuer_Success checks that validation passes when the
// discovery document at /.well-known/openid-configuration advertises exactly
// the URL that was passed in as the issuer.
func TestValidateOIDCIssuer_Success(t *testing.T) {
	// The handler must report the server's own URL, which only exists after
	// NewServer returns, so it reads it through this captured variable.
	var srv *httptest.Server
	discovery := func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/.well-known/openid-configuration" {
			http.NotFound(w, r)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(oidcProviderJSON{Issuer: srv.URL})
	}
	srv = httptest.NewServer(http.HandlerFunc(discovery))
	defer srv.Close()

	require.NoError(t, validateOIDCIssuer(context.Background(), srv.URL))
}
// TestValidateOIDCIssuer_UnreachableServer checks that an issuer URL with
// nothing listening behind it is reported as
// ErrIdentityProviderIssuerUnreachable.
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
	// Obtain a loopback address that was just bound and is now closed, rather
	// than hard-coding a fixed port that another local process might be using
	// (which would make the connection succeed and the test fail).
	srv := httptest.NewServer(http.NotFoundHandler())
	unreachableURL := srv.URL
	srv.Close()

	err := validateOIDCIssuer(context.Background(), unreachableURL)
	require.Error(t, err)
	assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable),
		"expected error %v, got %v", types.ErrIdentityProviderIssuerUnreachable, err)
}
// TestValidateOIDCIssuer_TrailingSlash checks that issuer comparison is exact.
// The caller supplies the issuer with a trailing slash, while the discovery
// document advertises it without one, so validation must report a mismatch.
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
	var srv *httptest.Server
	srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/.well-known/openid-configuration" {
			// Advertise the bare server URL, which has no trailing slash.
			w.Header().Set("Content-Type", "application/json")
			_ = json.NewEncoder(w).Encode(oidcProviderJSON{Issuer: srv.URL})
			return
		}
		http.NotFound(w, r)
	}))
	defer srv.Close()

	// Any path other than the discovery path gets a 404, which would surface
	// as an "unreachable" error. Expecting a mismatch therefore also requires
	// the discovery request to land on the right path despite the slash.
	issuerWithSlash := srv.URL + "/"
	err := validateOIDCIssuer(context.Background(), issuerWithSlash)
	require.Error(t, err)
	assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
}