mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[management] use the cache for the pkce state (#5516)
This commit is contained in:
@@ -423,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(srv.Close)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -703,8 +704,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(proxySrv.Close)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1134,8 +1136,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
t.Cleanup(proxySrv.Close)
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
@@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore {
|
||||
return Create(s, func() *nbgrpc.PKCEVerifierStore {
|
||||
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create PKCE verifier store: %v", err)
|
||||
}
|
||||
return pkceStore
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccessLogsManager() accesslogs.Manager {
|
||||
return Create(s, func() accesslogs.Manager {
|
||||
accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager())
|
||||
|
||||
@@ -248,7 +248,6 @@ func (s *BaseServer) Stop() error {
|
||||
_ = s.certManager.Listener().Close()
|
||||
}
|
||||
s.GRPCServer().Stop()
|
||||
s.ReverseProxyGRPCServer().Close()
|
||||
if s.proxyAuthClose != nil {
|
||||
s.proxyAuthClose()
|
||||
s.proxyAuthClose = nil
|
||||
|
||||
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
61
management/internals/shared/grpc/pkce_verifier.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/eko/gocache/lib/v4/cache"
|
||||
"github.com/eko/gocache/lib/v4/store"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
)
|
||||
|
||||
// PKCEVerifierStore manages PKCE verifiers for OAuth flows.
|
||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||
type PKCEVerifierStore struct {
|
||||
cache *cache.Cache[string]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection
|
||||
func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) {
|
||||
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cache store: %w", err)
|
||||
}
|
||||
|
||||
return &PKCEVerifierStore{
|
||||
cache: cache.New[string](cacheStore),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Store saves a PKCE verifier associated with an OAuth state parameter.
|
||||
// The verifier is stored with the specified TTL and will be automatically deleted after expiration.
|
||||
func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error {
|
||||
if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil {
|
||||
return fmt.Errorf("failed to store PKCE verifier: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndDelete retrieves and removes a PKCE verifier for the given state.
|
||||
// Returns the verifier and true if found, or empty string and false if not found.
|
||||
// This enforces single-use semantics for PKCE verifiers.
|
||||
func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) {
|
||||
verifier, err := s.cache.Get(s.ctx, state)
|
||||
if err != nil {
|
||||
log.Debugf("PKCE verifier not found for state")
|
||||
return "", false
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(s.ctx, state); err != nil {
|
||||
log.Warnf("Failed to delete PKCE verifier for state: %v", err)
|
||||
}
|
||||
|
||||
return verifier, true
|
||||
}
|
||||
@@ -82,20 +82,12 @@ type ProxyServiceServer struct {
|
||||
// OIDC configuration for proxy authentication
|
||||
oidcConfig ProxyOIDCConfig
|
||||
|
||||
// TODO: use database to store these instead?
|
||||
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
|
||||
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
|
||||
pkceVerifiers sync.Map
|
||||
pkceCleanupCancel context.CancelFunc
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
type pkceEntry struct {
|
||||
verifier string
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
@@ -107,42 +99,21 @@ type proxyConnection struct {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||
ctx := context.Background()
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
oidcConfig: oidcConfig,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
pkceCleanupCancel: cancel,
|
||||
}
|
||||
go s.cleanupPKCEVerifiers(ctx)
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
|
||||
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
|
||||
ticker := time.NewTicker(pkceVerifierTTL)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.pkceVerifiers.Range(func(key, value any) bool {
|
||||
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
|
||||
s.pkceVerifiers.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
|
||||
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
@@ -159,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops background goroutines.
|
||||
func (s *ProxyServiceServer) Close() {
|
||||
s.pkceCleanupCancel()
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
|
||||
s.serviceManager = manager
|
||||
}
|
||||
@@ -790,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
||||
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
|
||||
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
|
||||
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
|
||||
}
|
||||
|
||||
return &proto.GetOIDCURLResponse{
|
||||
Url: (&oauth2.Config{
|
||||
@@ -827,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
|
||||
// ValidateState validates the state parameter from an OAuth callback.
|
||||
// Returns the original redirect URL if valid, or an error if invalid.
|
||||
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
|
||||
v, ok := s.pkceVerifiers.LoadAndDelete(state)
|
||||
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
|
||||
if !ok {
|
||||
return "", "", errors.New("no verifier for state")
|
||||
}
|
||||
entry, ok := v.(pkceEntry)
|
||||
if !ok {
|
||||
return "", "", errors.New("invalid verifier for state")
|
||||
}
|
||||
if time.Since(entry.createdAt) > pkceVerifierTTL {
|
||||
return "", "", errors.New("PKCE verifier expired")
|
||||
}
|
||||
verifier = entry.verifier
|
||||
|
||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||
parts := strings.Split(state, "|")
|
||||
|
||||
@@ -5,11 +5,10 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"sync"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
@@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
@@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
|
||||
ctx := context.Background()
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
tokenStore: tokenStore,
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
|
||||
@@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||
}
|
||||
|
||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
|
||||
redirectURL := "https://app.example.com/callback"
|
||||
@@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
|
||||
// Old format had only 2 parts: base64(url)|hmac
|
||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err := s.ValidateState("base64url|hmac")
|
||||
_, _, err = s.ValidateState("base64url|hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
pkceVerifierStore: pkceStore,
|
||||
}
|
||||
|
||||
// Store with tampered HMAC
|
||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||
_, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state signature")
|
||||
}
|
||||
|
||||
@@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
Reference in New Issue
Block a user