Files
netbird/management/internals/shared/grpc/validate_session_test.go
2026-03-24 15:37:31 +01:00

385 lines
12 KiB
Go

//go:build integration
package grpc
import (
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"encoding/base64"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
	"github.com/netbirdio/netbird/management/server/store"
	"github.com/netbirdio/netbird/management/server/types"
	"github.com/netbirdio/netbird/proxy/auth"
	"github.com/netbirdio/netbird/shared/management/proto"
)
// validateSessionTestSetup bundles everything a ValidateSession test needs.
type validateSessionTestSetup struct {
	// proxyService is the server under test.
	proxyService *ProxyServiceServer
	// store is the test store backing the server's service and user managers.
	store store.Store
	// cleanup releases the test store and then cancels the setup context.
	// It is idempotent and is also registered with t.Cleanup, so resources are
	// released even if setup fails part-way or a caller never invokes it.
	cleanup func()
}

// setupValidateSessionTest builds a ProxyServiceServer whose managers read from
// a test store seeded with ../../../server/testdata/auth_callback.sql plus the
// services created by createTestProxies.
//
// Teardown is registered with t.Cleanup as soon as the store exists, so a
// failing require later in setup no longer leaks the store. Callers may still
// `defer setup.cleanup()`; the second invocation is a no-op.
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
	t.Helper()

	// Cancelable (instead of context.Background) so teardown can stop anything
	// the stores below bind to ctx. NOTE(review): presumably the token/PKCE
	// stores run background expiry on ctx — confirm.
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
	require.NoError(t, err)

	var once sync.Once
	cleanup := func() {
		once.Do(func() {
			storeCleanup()
			cancel()
		})
	}
	t.Cleanup(cleanup)

	serviceManager := &testValidateSessionServiceManager{store: testStore}
	usersManager := &testValidateSessionUsersManager{store: testStore}
	proxyManager := &testValidateSessionProxyManager{}

	tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
	require.NoError(t, err)
	pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
	require.NoError(t, err)

	proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
	proxyService.SetServiceManager(serviceManager)

	createTestProxies(t, ctx, testStore)

	return &validateSessionTestSetup{
		proxyService: proxyService,
		store:        testStore,
		cleanup:      cleanup,
	}
}
// createTestProxies inserts the two services used by the ValidateSession tests
// into testStore. Both belong to account "testAccountId", are enabled, use
// bearer auth, and share one freshly generated session key pair:
//   - "testProxyId" on test-proxy.example.com, with no distribution groups;
//   - "restrictedProxyId" on restricted-proxy.example.com, limited to the
//     "allowedGroupId" distribution group.
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
	t.Helper()

	publicKey, privateKey := generateSessionKeyPair(t)

	// newService builds a bearer-auth service with the shared key pair.
	// A nil groups slice leaves DistributionGroups unset.
	newService := func(id, name, domain string, groups []string) *service.Service {
		return &service.Service{
			ID:                id,
			AccountID:         "testAccountId",
			Name:              name,
			Domain:            domain,
			Enabled:           true,
			SessionPrivateKey: privateKey,
			SessionPublicKey:  publicKey,
			Auth: service.AuthConfig{
				BearerAuth: &service.BearerAuthConfig{
					Enabled:            true,
					DistributionGroups: groups,
				},
			},
		}
	}

	fixtures := []*service.Service{
		newService("testProxyId", "Test Proxy", "test-proxy.example.com", nil),
		newService("restrictedProxyId", "Restricted Proxy", "restricted-proxy.example.com", []string{"allowedGroupId"}),
	}
	for _, svc := range fixtures {
		require.NoError(t, testStore.CreateService(ctx, svc))
	}
}
// generateSessionKeyPair creates a random Ed25519 key pair and returns the
// public key and the private key, in that order, as standard base64 strings.
// The test fails immediately if key generation errors.
func generateSessionKeyPair(t *testing.T) (string, string) {
	t.Helper()

	publicKey, privateKey, genErr := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, genErr)

	enc := base64.StdEncoding
	encodedPublic := enc.EncodeToString(publicKey)
	encodedPrivate := enc.EncodeToString(privateKey)
	return encodedPublic, encodedPrivate
}
// createSessionToken signs a session token for userID on domain with the
// base64-encoded private key privKeyB64. It uses the OIDC auth method and
// passes time.Hour as the token lifetime (presumably the TTL — see
// sessionkey.SignToken). The test fails immediately if signing errors.
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
	t.Helper()

	const lifetime = time.Hour
	signed, signErr := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, lifetime)
	require.NoError(t, signErr)

	return signed
}
// TestValidateSession_UserAllowed checks that ValidateSession accepts a token
// for "allowedUserId" on the unrestricted test-proxy.example.com service. It
// also checks that the response echoes the user ID with no denial reason.
func TestValidateSession_UserAllowed(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	ctx := context.Background()
	const domain = "test-proxy.example.com"

	// Named svc rather than proxy so the imported proxy package is not shadowed.
	svc, err := setup.store.GetServiceByID(ctx, store.LockingStrengthNone, "testAccountId", "testProxyId")
	require.NoError(t, err)

	token := createSessionToken(t, svc.SessionPrivateKey, "allowedUserId", domain)

	resp, err := setup.proxyService.ValidateSession(ctx, &proto.ValidateSessionRequest{
		Domain:       domain,
		SessionToken: token,
	})
	require.NoError(t, err)
	// Fail with a clear message instead of a nil-pointer panic below.
	require.NotNil(t, resp)

	assert.True(t, resp.Valid, "User should be allowed access")
	assert.Equal(t, "allowedUserId", resp.UserId)
	assert.Empty(t, resp.DeniedReason)
}
// TestValidateSession_UserNotInAllowedGroup checks that a token for
// "nonGroupUserId" is denied on the group-restricted service
// (restricted-proxy.example.com) with the "not_in_group" reason. The response
// must still report the user ID.
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	ctx := context.Background()
	const domain = "restricted-proxy.example.com"

	// Named svc rather than proxy so the imported proxy package is not shadowed.
	svc, err := setup.store.GetServiceByID(ctx, store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
	require.NoError(t, err)

	token := createSessionToken(t, svc.SessionPrivateKey, "nonGroupUserId", domain)

	resp, err := setup.proxyService.ValidateSession(ctx, &proto.ValidateSessionRequest{
		Domain:       domain,
		SessionToken: token,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.False(t, resp.Valid, "User not in group should be denied")
	assert.Equal(t, "not_in_group", resp.DeniedReason)
	assert.Equal(t, "nonGroupUserId", resp.UserId)
}
// TestValidateSession_UserInDifferentAccount checks that a token for
// "otherAccountUserId" is denied on a testAccountId service with the
// "account_mismatch" reason. NOTE(review): relies on the auth_callback.sql
// fixture placing that user in a different account.
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	ctx := context.Background()
	const domain = "test-proxy.example.com"

	// Named svc rather than proxy so the imported proxy package is not shadowed.
	svc, err := setup.store.GetServiceByID(ctx, store.LockingStrengthNone, "testAccountId", "testProxyId")
	require.NoError(t, err)

	token := createSessionToken(t, svc.SessionPrivateKey, "otherAccountUserId", domain)

	resp, err := setup.proxyService.ValidateSession(ctx, &proto.ValidateSessionRequest{
		Domain:       domain,
		SessionToken: token,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.False(t, resp.Valid, "User in different account should be denied")
	assert.Equal(t, "account_mismatch", resp.DeniedReason)
}
// TestValidateSession_UserNotFound checks that a correctly signed token for a
// user ID absent from the store is denied with the "user_not_found" reason.
func TestValidateSession_UserNotFound(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	ctx := context.Background()
	const domain = "test-proxy.example.com"

	// Named svc rather than proxy so the imported proxy package is not shadowed.
	svc, err := setup.store.GetServiceByID(ctx, store.LockingStrengthNone, "testAccountId", "testProxyId")
	require.NoError(t, err)

	token := createSessionToken(t, svc.SessionPrivateKey, "nonExistentUserId", domain)

	resp, err := setup.proxyService.ValidateSession(ctx, &proto.ValidateSessionRequest{
		Domain:       domain,
		SessionToken: token,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.False(t, resp.Valid, "Non-existent user should be denied")
	assert.Equal(t, "user_not_found", resp.DeniedReason)
}
// TestValidateSession_ProxyNotFound checks that a token bound to a domain with
// no registered service is denied with the "service_not_found" reason. The
// token is signed with testProxyId's private key but names unknown-proxy.example.com.
func TestValidateSession_ProxyNotFound(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	ctx := context.Background()
	const domain = "unknown-proxy.example.com"

	// Named svc rather than proxy so the imported proxy package is not shadowed.
	svc, err := setup.store.GetServiceByID(ctx, store.LockingStrengthNone, "testAccountId", "testProxyId")
	require.NoError(t, err)

	token := createSessionToken(t, svc.SessionPrivateKey, "allowedUserId", domain)

	resp, err := setup.proxyService.ValidateSession(ctx, &proto.ValidateSessionRequest{
		Domain:       domain,
		SessionToken: token,
	})
	require.NoError(t, err)
	require.NotNil(t, resp)

	assert.False(t, resp.Valid, "Unknown proxy should be denied")
	assert.Equal(t, "service_not_found", resp.DeniedReason)
}
// TestValidateSession_InvalidToken checks that a malformed session token is
// reported as an invalid session ("invalid_token") rather than as an RPC error.
func TestValidateSession_InvalidToken(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	req := &proto.ValidateSessionRequest{
		Domain:       "test-proxy.example.com",
		SessionToken: "invalid-token",
	}
	resp, err := setup.proxyService.ValidateSession(context.Background(), req)
	require.NoError(t, err)

	assert.False(t, resp.Valid, "Invalid token should be denied")
	assert.Equal(t, "invalid_token", resp.DeniedReason)
}
// TestValidateSession_MissingDomain checks that a request without a domain is
// denied with a reason mentioning "missing", not rejected with an RPC error.
func TestValidateSession_MissingDomain(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	req := &proto.ValidateSessionRequest{SessionToken: "some-token"}
	resp, err := setup.proxyService.ValidateSession(context.Background(), req)
	require.NoError(t, err)

	assert.False(t, resp.Valid)
	assert.Contains(t, resp.DeniedReason, "missing")
}
// TestValidateSession_MissingToken checks that a request without a session
// token is denied with a reason mentioning "missing", not rejected with an RPC error.
func TestValidateSession_MissingToken(t *testing.T) {
	setup := setupValidateSessionTest(t)
	defer setup.cleanup()

	req := &proto.ValidateSessionRequest{Domain: "test-proxy.example.com"}
	resp, err := setup.proxyService.ValidateSession(context.Background(), req)
	require.NoError(t, err)

	assert.False(t, resp.Valid)
	assert.Contains(t, resp.DeniedReason, "missing")
}
// testValidateSessionServiceManager is the service-manager test double used by
// the ValidateSession tests. The four lookup methods (GetGlobalServices,
// GetServiceByID, GetAccountServices, GetServiceByDomain) read straight from
// the backing store. Every other method is a stub that returns zero values and
// a nil error.
type testValidateSessionServiceManager struct {
	store store.Store
}

// Store-backed lookups.

func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
	return m.store.GetServices(ctx, store.LockingStrengthNone)
}

func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
	return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}

func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
	return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}

func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, host string) (*service.Service, error) {
	return m.store.GetServiceByDomain(ctx, host)
}

// No-op stubs.

func (*testValidateSessionServiceManager) GetAllServices(context.Context, string, string) ([]*service.Service, error) {
	return nil, nil
}

func (*testValidateSessionServiceManager) GetService(context.Context, string, string, string) (*service.Service, error) {
	return nil, nil
}

func (*testValidateSessionServiceManager) CreateService(context.Context, string, string, *service.Service) (*service.Service, error) {
	return nil, nil
}

func (*testValidateSessionServiceManager) UpdateService(context.Context, string, string, *service.Service) (*service.Service, error) {
	return nil, nil
}

func (*testValidateSessionServiceManager) DeleteService(context.Context, string, string, string) error {
	return nil
}

func (*testValidateSessionServiceManager) DeleteAllServices(context.Context, string, string) error {
	return nil
}

func (*testValidateSessionServiceManager) SetCertificateIssuedAt(context.Context, string, string) error {
	return nil
}

func (*testValidateSessionServiceManager) SetStatus(context.Context, string, string, service.Status) error {
	return nil
}

func (*testValidateSessionServiceManager) ReloadAllServicesForAccount(context.Context, string) error {
	return nil
}

func (*testValidateSessionServiceManager) ReloadService(context.Context, string, string) error {
	return nil
}

func (*testValidateSessionServiceManager) GetServiceIDByTargetID(context.Context, string, string) (string, error) {
	return "", nil
}

func (*testValidateSessionServiceManager) CreateServiceFromPeer(context.Context, string, string, *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
	return nil, nil
}

func (*testValidateSessionServiceManager) RenewServiceFromPeer(context.Context, string, string, string) error {
	return nil
}

func (*testValidateSessionServiceManager) StopServiceFromPeer(context.Context, string, string, string) error {
	return nil
}

func (*testValidateSessionServiceManager) StartExposeReaper(context.Context) {}

func (*testValidateSessionServiceManager) GetActiveClusters(context.Context, string, string) ([]proxy.Cluster, error) {
	return nil, nil
}
// testValidateSessionProxyManager is a stateless proxy-manager test double.
// Every method succeeds without doing anything: lookups return zero values,
// and IsClusterAddressAvailable always reports true.
type testValidateSessionProxyManager struct{}

func (*testValidateSessionProxyManager) Connect(context.Context, string, string, string, *string) error {
	return nil
}

func (*testValidateSessionProxyManager) Disconnect(context.Context, string) error {
	return nil
}

func (*testValidateSessionProxyManager) Heartbeat(context.Context, string, string, string) error {
	return nil
}

func (*testValidateSessionProxyManager) GetActiveClusterAddresses(context.Context) ([]string, error) {
	return nil, nil
}

func (*testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(context.Context, string) ([]string, error) {
	return nil, nil
}

func (*testValidateSessionProxyManager) GetActiveClusters(context.Context) ([]proxy.Cluster, error) {
	return nil, nil
}

func (*testValidateSessionProxyManager) CleanupStale(context.Context, time.Duration) error {
	return nil
}

func (*testValidateSessionProxyManager) GetAccountProxy(context.Context, string) (*proxy.Proxy, error) {
	return nil, nil
}

func (*testValidateSessionProxyManager) CountAccountProxies(context.Context, string) (int64, error) {
	return 0, nil
}

func (*testValidateSessionProxyManager) IsClusterAddressAvailable(context.Context, string, string) (bool, error) {
	const available = true
	return available, nil
}

func (*testValidateSessionProxyManager) DeleteProxy(context.Context, string) error {
	return nil
}
// testValidateSessionUsersManager is a users-manager test double that resolves
// users directly from the backing store.
type testValidateSessionUsersManager struct {
	store store.Store
}

// GetUser fetches the user with the given ID from the backing store, using
// store.LockingStrengthNone.
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, id string) (*types.User, error) {
	user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
	return user, err
}