Compare commits

...

5 Commits

Author SHA1 Message Date
pascal
37b9905b68 Merge branch 'main' into feature/log-most-busy-peers 2026-04-23 15:41:26 +02:00
Vlad
b6038e8acd [management] refactor: changeable pat rate limiting (#5946) 2026-04-23 15:13:22 +02:00
pascal
92e53d6319 generic settings overrider 2026-04-21 18:31:53 +02:00
pascal
8a7d78ddf3 make configurable via env 2026-04-21 15:42:44 +02:00
pascal
ea83cbf917 log the most busy peers 2026-04-21 15:27:05 +02:00
14 changed files with 805 additions and 61 deletions

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
@@ -30,9 +31,11 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/util/crypt"
)
@@ -72,6 +75,23 @@ func (s *BaseServer) CacheStore() cachestore.StoreInterface {
})
}
// SettingOverrider returns the shared setting overrider backed by Redis.
// When no Redis address is configured, a no-op overrider is returned instead.
func (s *BaseServer) SettingOverrider() *settingoverrider.Overrider {
	return Create(s, func() *settingoverrider.Overrider {
		addr := nbcache.GetAddrFromEnv()
		if addr == "" {
			return settingoverrider.NewNoop()
		}
		overrider, err := settingoverrider.New(context.Background(), addr)
		if err != nil {
			log.Fatalf("failed to create setting overrider: %v", err)
		}
		return overrider
	})
}
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
@@ -109,7 +129,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -117,6 +137,15 @@ func (s *BaseServer) APIHandler() http.Handler {
})
}
// RateLimiter returns the shared API rate limiter, configured from the
// NB_API_RATE_LIMITING_* environment variables.
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
	return Create(s, func() *middleware.APIRateLimiter {
		config, enabled := middleware.RateLimiterConfigFromEnv()
		rl := middleware.NewAPIRateLimiter(config)
		rl.SetEnabled(enabled)
		return rl
	})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
"github.com/netbirdio/netbird/version"
@@ -123,6 +124,15 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.PeersManager()
s.GeoLocationManager()
s.SettingOverrider().Poll(settingoverrider.DefaultInterval, "managementLogLevel", func(value string) error {
level, err := log.ParseLevel(value)
if err != nil {
return fmt.Errorf("parsing log level %q: %w", value, err)
}
log.SetLevel(level)
return nil
})
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err)
@@ -235,6 +245,7 @@ func (s *BaseServer) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.SettingOverrider().Close()
s.IntegratedValidator().Stop(ctx)
if s.GeoLocationManager() != nil {
_ = s.GeoLocationManager().Stop()

View File

@@ -5,9 +5,6 @@ import (
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
@@ -66,14 +63,11 @@ import (
)
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
apiPrefix = "/api"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
}
authMiddleware := middleware.NewAuthMiddleware(
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
rateLimiter,
appMetrics.GetMeter(),
)

View File

@@ -43,14 +43,9 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
rateLimiter *APIRateLimiter,
meter metric.Meter,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker
if meter != nil {
var err error
@@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
ctx := r.Context()

View File

@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)

View File

@@ -4,14 +4,27 @@ import (
"context"
"net"
"net/http"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
const (
	// Environment variables controlling the API (PAT) rate limiter.
	RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
	RateLimitingBurstEnv   = "NB_API_RATE_LIMITING_BURST"
	RateLimitingRPMEnv     = "NB_API_RATE_LIMITING_RPM"

	defaultAPIRPM   = 6   // default token refill rate (requests per minute)
	defaultAPIBurst = 500 // default bucket burst size
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
}
}
// RateLimiterConfigFromEnv builds a RateLimiterConfig from the
// NB_API_RATE_LIMITING_* environment variables.
//
// RPM and burst fall back to their defaults when unset, unparsable, or
// non-positive (the latter two log a warning). The returned enabled flag is
// true only when NB_API_RATE_LIMITING_ENABLED is exactly "true".
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
	rpm := positiveIntFromEnv(RateLimitingRPMEnv, defaultAPIRPM)
	burst := positiveIntFromEnv(RateLimitingBurstEnv, defaultAPIBurst)
	return &RateLimiterConfig{
		RequestsPerMinute: float64(rpm),
		Burst:             burst,
		CleanupInterval:   6 * time.Hour,
		LimiterTTL:        24 * time.Hour,
	}, os.Getenv(RateLimitingEnabledEnv) == "true"
}

// positiveIntFromEnv reads a positive integer from the named environment
// variable, returning def (with a logged warning) when the value is set but
// unparsable or non-positive, and def silently when the variable is unset.
func positiveIntFromEnv(key string, def int) int {
	v := os.Getenv(key)
	if v == "" {
		return def
	}
	value, err := strconv.Atoi(v)
	if err != nil {
		log.Warnf("parsing %s env var: %v, using default %d", key, err, def)
		return def
	}
	if value <= 0 {
		log.Warnf("%s=%d is non-positive, using default %d", key, value, def)
		return def
	}
	return value
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
enabled atomic.Bool
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
rl.enabled.Store(true)
go rl.cleanupLoop()
return rl
}
// SetEnabled atomically turns rate limiting on or off at runtime.
// While disabled, Allow/Wait/Middleware pass every request through.
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
	rl.enabled.Store(enabled)
}

// Enabled reports whether rate limiting is currently enforced.
func (rl *APIRateLimiter) Enabled() bool {
	return rl.enabled.Load()
}
// UpdateConfig applies a new rate/burst configuration at runtime.
// A nil config or non-positive rpm/burst is ignored with a warning, so a bad
// override can never zero out the limiter. Existing per-key limiters are
// updated in place; their current token balance is not refilled.
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
	if config == nil {
		return
	}
	if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
		log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
		return
	}
	// rate.Limit is expressed in events per second.
	newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
	newBurst := config.Burst
	rl.mu.Lock()
	// Store the new values so limiters created after this point use them.
	rl.config.RequestsPerMinute = config.RequestsPerMinute
	rl.config.Burst = newBurst
	// Snapshot existing limiters under the lock, but apply SetLimit/SetBurst
	// after releasing it to keep the critical section short.
	snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
	for _, entry := range rl.limiters {
		snapshot = append(snapshot, entry.limiter)
	}
	rl.mu.Unlock()
	for _, l := range snapshot {
		l.SetLimit(newRPS)
		l.SetBurst(newBurst)
	}
}
// Allow reports whether a request for the given key (token) is allowed.
// A disabled limiter admits every request.
func (rl *APIRateLimiter) Allow(key string) bool {
	if rl.enabled.Load() {
		return rl.getLimiter(key).Allow()
	}
	return true
}
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key.
// Returns an error if the context is canceled. A disabled limiter returns
// immediately with no error.
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
	if rl.enabled.Load() {
		return rl.getLimiter(key).Wait(ctx)
	}
	return nil
}
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,8 +1,10 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}
// TestAPIRateLimiter_SetEnabled verifies that disabling the limiter bypasses
// rate checks and that re-enabling it resumes with the previous bucket state.
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
	rl := NewAPIRateLimiter(&RateLimiterConfig{
		RequestsPerMinute: 60,
		Burst:             1,
		CleanupInterval:   time.Minute,
		LimiterTTL:        time.Minute,
	})
	defer rl.Stop()

	assert.True(t, rl.Allow("key"))
	assert.False(t, rl.Allow("key"), "burst exhausted while enabled")

	rl.SetEnabled(false)
	assert.False(t, rl.Enabled())
	for i := 0; i < 5; i++ {
		assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
	}

	rl.SetEnabled(true)
	assert.True(t, rl.Enabled())
	// The bucket was not refilled while disabled, so the key is still limited.
	assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
// TestAPIRateLimiter_UpdateConfig verifies that a runtime config change is
// visible to keys created after the update.
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
	rl := NewAPIRateLimiter(&RateLimiterConfig{
		RequestsPerMinute: 60,
		Burst:             2,
		CleanupInterval:   time.Minute,
		LimiterTTL:        time.Minute,
	})
	defer rl.Stop()

	assert.True(t, rl.Allow("k1"))
	assert.True(t, rl.Allow("k1"))
	assert.False(t, rl.Allow("k1"), "burst=2 exhausted")

	rl.UpdateConfig(&RateLimiterConfig{
		RequestsPerMinute: 60,
		Burst:             10,
		CleanupInterval:   time.Minute,
		LimiterTTL:        time.Minute,
	})

	// New burst applies to existing keys in place; bucket refills up to new burst over time,
	// but importantly newly-added keys use the updated config immediately.
	assert.True(t, rl.Allow("k2"))
	for i := 0; i < 9; i++ {
		assert.True(t, rl.Allow("k2"))
	}
	assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
// TestAPIRateLimiter_UpdateConfig_NilIgnored verifies a nil update is a no-op.
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
	rl := NewAPIRateLimiter(&RateLimiterConfig{
		RequestsPerMinute: 60,
		Burst:             1,
		CleanupInterval:   time.Minute,
		LimiterTTL:        time.Minute,
	})
	defer rl.Stop()

	rl.UpdateConfig(nil) // must not panic or zero the config

	// The original burst=1 is still in force.
	assert.True(t, rl.Allow("k"))
	assert.False(t, rl.Allow("k"))
}
// TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored verifies that updates
// carrying non-positive rpm or burst are rejected and the prior config kept.
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
	rl := NewAPIRateLimiter(&RateLimiterConfig{
		RequestsPerMinute: 60,
		Burst:             1,
		CleanupInterval:   time.Minute,
		LimiterTTL:        time.Minute,
	})
	defer rl.Stop()

	assert.True(t, rl.Allow("k"))
	assert.False(t, rl.Allow("k"))

	// Each of these invalid updates must be ignored by UpdateConfig.
	rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
	rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
	rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})

	rl.Reset("k")
	assert.True(t, rl.Allow("k"))
	assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
// TestAPIRateLimiter_ConcurrentAllowAndUpdate hammers Allow, UpdateConfig and
// SetEnabled from concurrent goroutines. It asserts nothing about outcomes;
// its value is surfacing data races under `go test -race`.
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
	rl := NewAPIRateLimiter(&RateLimiterConfig{
		RequestsPerMinute: 600,
		Burst:             10,
		CleanupInterval:   time.Minute,
		LimiterTTL:        time.Minute,
	})
	defer rl.Stop()

	var wg sync.WaitGroup
	stop := make(chan struct{})

	// 8 readers, each spinning on its own key until told to stop.
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			key := fmt.Sprintf("k%d", id)
			for {
				select {
				case <-stop:
					return
				default:
					rl.Allow(key)
				}
			}
		}(i)
	}

	// One writer cycling through configs and toggling enablement.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < 200; i++ {
			select {
			case <-stop:
				return
			default:
				rl.UpdateConfig(&RateLimiterConfig{
					RequestsPerMinute: float64(30 + (i % 90)),
					Burst:             1 + (i % 20),
					CleanupInterval:   time.Minute,
					LimiterTTL:        time.Minute,
				})
				rl.SetEnabled(i%2 == 0)
			}
		}
	}()

	time.Sleep(100 * time.Millisecond)
	close(stop)
	wg.Wait()
}
// TestRateLimiterConfigFromEnv covers explicit values, the disabled flag,
// empty values (which the parser treats the same as unset) and non-positive
// fallbacks to the defaults.
func TestRateLimiterConfigFromEnv(t *testing.T) {
	t.Setenv(RateLimitingEnabledEnv, "true")
	t.Setenv(RateLimitingRPMEnv, "42")
	t.Setenv(RateLimitingBurstEnv, "7")
	cfg, enabled := RateLimiterConfigFromEnv()
	assert.True(t, enabled)
	assert.Equal(t, float64(42), cfg.RequestsPerMinute)
	assert.Equal(t, 7, cfg.Burst)

	t.Setenv(RateLimitingEnabledEnv, "false")
	_, enabled = RateLimiterConfigFromEnv()
	assert.False(t, enabled)

	// t.Setenv with "" sets an empty string, which the parser skips like unset.
	t.Setenv(RateLimitingEnabledEnv, "")
	t.Setenv(RateLimitingRPMEnv, "")
	t.Setenv(RateLimitingBurstEnv, "")
	cfg, enabled = RateLimiterConfigFromEnv()
	assert.False(t, enabled)
	assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
	assert.Equal(t, defaultAPIBurst, cfg.Burst)

	t.Setenv(RateLimitingRPMEnv, "0")
	t.Setenv(RateLimitingBurstEnv, "-5")
	cfg, _ = RateLimiterConfigFromEnv()
	assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
	assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -0,0 +1,120 @@
package settingoverrider
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus"
)
const (
	// DefaultInterval is the suggested polling interval for Poll callers.
	DefaultInterval = 5 * time.Minute
)

// ApplyFunc is called with the raw Redis string value whenever it changes.
// The function is responsible for parsing and applying the value.
// Return an error to log a warning without stopping the polling loop.
type ApplyFunc func(value string) error

// Overrider holds a shared Redis connection and allows registering
// individual settings that are polled independently.
type Overrider struct {
	client *redis.Client
	cancel context.CancelFunc // stops every Poll goroutine via ctx
	ctx    context.Context    // parent context for all polling goroutines
	noop   bool               // true for NewNoop instances; Poll/Close do nothing
}
// New creates an Overrider by connecting to Redis at the given address.
// The address should follow the Redis URL format (e.g. "redis://localhost:6379").
func New(ctx context.Context, redisAddr string) (*Overrider, error) {
	if redisAddr == "" {
		return nil, fmt.Errorf("redis address is empty")
	}

	opts, err := redis.ParseURL(redisAddr)
	if err != nil {
		return nil, fmt.Errorf("parsing redis address: %w", err)
	}
	client := redis.NewClient(opts)

	// Verify connectivity up front so misconfiguration fails fast.
	pingCtx, cancelPing := context.WithTimeout(ctx, 2*time.Second)
	defer cancelPing()
	if err := client.Ping(pingCtx).Err(); err != nil {
		_ = client.Close()
		return nil, fmt.Errorf("connecting to redis: %w", err)
	}

	pollCtx, stop := context.WithCancel(ctx)
	return &Overrider{client: client, cancel: stop, ctx: pollCtx}, nil
}
// NewNoop returns an Overrider that does nothing.
// Poll calls are silently ignored and Close is a no-op.
// Used by callers when no Redis address is configured.
func NewNoop() *Overrider {
	return &Overrider{noop: true}
}
// Close stops all polling goroutines and closes the underlying Redis client.
// On a no-op Overrider it returns nil without doing anything.
func (o *Overrider) Close() error {
	if o.noop {
		return nil
	}
	o.cancel() // cancels o.ctx, which every Poll goroutine watches
	return o.client.Close()
}
// Poll starts a background goroutine that polls a single Redis key at the given
// interval and calls apply whenever the value changes. A failed apply is logged
// and retried on the next tick (the value is not marked as seen). The goroutine
// stops when the Overrider is closed. On a no-op Overrider, Poll does nothing.
func (o *Overrider) Poll(interval time.Duration, redisKey string, apply ApplyFunc) {
	if o.noop {
		return
	}
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		var lastSeen *string
		for {
			select {
			case <-o.ctx.Done():
				log.WithContext(o.ctx).Infof("Stopping settings overrider for key %q", redisKey)
				return
			case <-ticker.C:
				getCtx, cancel := context.WithTimeout(o.ctx, 5*time.Second)
				val, err := o.client.Get(getCtx, redisKey).Result()
				cancel()
				// BUG FIX: the error must be inspected before the value. Get
				// returns "" on failure, so the previous `redis.Nil || val == ""`
				// check swallowed every real Redis error and the error-logging
				// branch was unreachable.
				if err != nil {
					if errors.Is(err, redis.Nil) {
						// Key not set — nothing to apply, wait for it to appear.
						continue
					}
					if o.ctx.Err() != nil {
						return
					}
					log.WithContext(o.ctx).Errorf("Unable to get setting %q from Redis: %v", redisKey, err)
					continue
				}
				if val == "" {
					// Treat an explicitly empty value like an absent key.
					continue
				}
				if lastSeen != nil && *lastSeen == val {
					continue
				}
				if err := apply(val); err != nil {
					log.WithContext(o.ctx).Warnf("Failed to apply setting %q with value %q: %v", redisKey, val, err)
					continue
				}
				lastSeen = &val
			}
		}
	}()
}

View File

@@ -0,0 +1,111 @@
package settingoverrider
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
"github.com/testcontainers/testcontainers-go/wait"
)
// TestPoll_AppliesSettingFromRedis verifies a value present in Redis before
// polling starts is delivered to the apply callback.
func TestPoll_AppliesSettingFromRedis(t *testing.T) {
	o, client := setupOverrider(t)
	key := "test-setting-key"
	require.NoError(t, client.Set(context.Background(), key, "hello", 0).Err())

	var applied atomic.Value
	o.Poll(100*time.Millisecond, key, func(value string) error {
		applied.Store(value)
		return nil
	})

	assert.Eventually(t, func() bool {
		v := applied.Load()
		return v != nil && v.(string) == "hello"
	}, 5*time.Second, 50*time.Millisecond)
}
// TestPoll_IndependentSettings verifies two keys polled on the same Overrider
// are delivered to their respective callbacks independently.
func TestPoll_IndependentSettings(t *testing.T) {
	o, client := setupOverrider(t)
	require.NoError(t, client.Set(context.Background(), "key-a", "val-a", 0).Err())
	require.NoError(t, client.Set(context.Background(), "key-b", "val-b", 0).Err())

	var gotA, gotB atomic.Value
	o.Poll(100*time.Millisecond, "key-a", func(v string) error { gotA.Store(v); return nil })
	o.Poll(100*time.Millisecond, "key-b", func(v string) error { gotB.Store(v); return nil })

	assert.Eventually(t, func() bool {
		a, b := gotA.Load(), gotB.Load()
		return a != nil && a.(string) == "val-a" && b != nil && b.(string) == "val-b"
	}, 5*time.Second, 50*time.Millisecond)
}
// TestPoll_SkipsDuplicateValues verifies apply is invoked only once for an
// unchanged value across multiple ticks (dedup via the last-seen comparison).
func TestPoll_SkipsDuplicateValues(t *testing.T) {
	o, client := setupOverrider(t)
	key := "test-dedup"
	require.NoError(t, client.Set(context.Background(), key, "same", 0).Err())

	var count atomic.Int32
	o.Poll(100*time.Millisecond, key, func(string) error {
		count.Add(1)
		return nil
	})

	// wait for a few ticks
	time.Sleep(600 * time.Millisecond)
	assert.Equal(t, int32(1), count.Load(), "Apply should be called only once for unchanged value")
}
// setupOverrider starts a throwaway Redis container and returns an Overrider
// connected to it, plus a separate Redis client for seeding test data.
// All resources (container, overrider, client) are released via t.Cleanup.
func setupOverrider(t *testing.T) (*Overrider, *redis.Client) {
	t.Helper()
	ctx := context.Background()

	redisContainer, err := testcontainersredis.RunContainer(ctx,
		testcontainers.WithImage("redis:7"),
		testcontainers.WithWaitStrategy(
			wait.ForListeningPort("6379/tcp"),
		),
	)
	require.NoError(t, err, "Failed to create redis test container")
	t.Cleanup(func() {
		if err := redisContainer.Terminate(ctx); err != nil {
			t.Logf("failed to terminate redis container: %s", err)
		}
	})

	redisURL, err := redisContainer.ConnectionString(ctx)
	require.NoError(t, err)

	o, err := New(ctx, redisURL)
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := o.Close(); err != nil {
			t.Logf("failed to close overrider: %s", err)
		}
	})

	// separate client for test setup (setting keys)
	options, err := redis.ParseURL(redisURL)
	require.NoError(t, err)
	client := redis.NewClient(options)
	t.Cleanup(func() {
		if err := client.Close(); err != nil {
			t.Logf("failed to close redis client: %s", err)
		}
	})

	return o, client
}

View File

@@ -18,7 +18,9 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -114,7 +116,24 @@ var (
}
}()
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
overrider := settingoverrider.NewNoop()
if redisAddr := cache.GetAddrFromEnv(); redisAddr != "" {
overrider, err = settingoverrider.New(cmd.Context(), redisAddr)
if err != nil {
return fmt.Errorf("failed to create setting overrider: %w", err)
}
defer func() { _ = overrider.Close() }()
}
overrider.Poll(settingoverrider.DefaultInterval, "signalLogLevel", func(value string) error {
level, err := log.ParseLevel(value)
if err != nil {
return fmt.Errorf("parsing log level %q: %w", value, err)
}
log.SetLevel(level)
return nil
})
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, overrider)
if err != nil {
return fmt.Errorf("creating signal server: %v", err)
}

View File

@@ -0,0 +1,135 @@
package server
import (
"context"
"math"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
	defaultSendRateLogInterval = 5 * time.Minute
	defaultSendRateTopPercent  = 0.95

	// Environment variables overriding the two defaults above.
	envSendRateLogInterval = "NB_SIGNAL_SEND_RATE_LOG_INTERVAL"
	envSendRateTopPercent  = "NB_SIGNAL_SEND_RATE_LOG_TOP_PERCENT"
)

// sendRateTracker tracks per-key message counts and logs the busiest peers periodically.
type sendRateTracker struct {
	mu     sync.Mutex
	counts map[string]int64 // per-key counts since the last snapshot; guarded by mu

	// atomic so they can be updated by the setting overrider without locking
	intervalNs atomic.Int64
	// topPercent stored as float64 bits for atomic access
	topPercentBits atomic.Uint64
}
// newSendRateTracker builds a tracker configured from the environment,
// silently falling back to the defaults for missing or invalid values.
func newSendRateTracker() *sendRateTracker {
	interval := defaultSendRateLogInterval
	if raw := os.Getenv(envSendRateLogInterval); raw != "" {
		if d, err := time.ParseDuration(raw); err == nil && d > 0 {
			interval = d
		}
	}

	topPercent := defaultSendRateTopPercent
	if raw := os.Getenv(envSendRateTopPercent); raw != "" {
		if p, err := strconv.ParseFloat(raw, 64); err == nil && p > 0 && p <= 1 {
			topPercent = p
		}
	}

	log.Debugf("send rate tracker: interval=%s, top_percent=%.2f", interval, topPercent)

	tracker := &sendRateTracker{counts: make(map[string]int64)}
	tracker.intervalNs.Store(int64(interval))
	tracker.topPercentBits.Store(math.Float64bits(topPercent))
	return tracker
}
// getInterval returns the current logging interval.
func (t *sendRateTracker) getInterval() time.Duration {
	return time.Duration(t.intervalNs.Load())
}

// setInterval atomically updates the logging interval; logSendRates re-reads
// it on each tick.
func (t *sendRateTracker) setInterval(d time.Duration) {
	t.intervalNs.Store(int64(d))
}

// getTopPercent returns the current top-percent logging threshold.
func (t *sendRateTracker) getTopPercent() float64 {
	return math.Float64frombits(t.topPercentBits.Load())
}

// setTopPercent atomically updates the top-percent logging threshold.
func (t *sendRateTracker) setTopPercent(p float64) {
	t.topPercentBits.Store(math.Float64bits(p))
}
// increment records one sent message for the given key.
func (t *sendRateTracker) increment(key string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.counts[key]++
}
// resetAndSnapshot atomically returns current counts and resets the tracker.
func (t *sendRateTracker) resetAndSnapshot() map[string]int64 {
	t.mu.Lock()
	defer t.mu.Unlock()
	snapshot := t.counts
	t.counts = make(map[string]int64, len(snapshot))
	return snapshot
}
// logSendRates periodically logs peers in the top percentile of the busiest
// peer until ctx is cancelled. Qualifying peers are sorted by count so the
// busiest ones are always the ones logged: the previous version iterated the
// snapshot map directly, so when more than 100 peers crossed the threshold an
// arbitrary (randomly ordered) subset was logged instead of the top senders.
func (t *sendRateTracker) logSendRates(ctx context.Context) {
	currentInterval := t.getInterval()
	ticker := time.NewTicker(currentInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// apply an interval change made via the setting overrider
			if newInterval := t.getInterval(); newInterval != currentInterval {
				currentInterval = newInterval
				ticker.Reset(currentInterval)
			}
			snap := t.resetAndSnapshot()
			if len(snap) == 0 {
				continue
			}
			var maxCount int64
			for _, count := range snap {
				if count > maxCount {
					maxCount = count
				}
			}
			// peers at or above this fraction of the busiest peer's count qualify
			threshold := int64(float64(maxCount) * t.getTopPercent())
			intervalMin := currentInterval.Minutes()
			log.Debugf("send rate stats: %d unique peers in last %.0fs, max rate %.1f msg/min",
				len(snap), currentInterval.Seconds(), float64(maxCount)/intervalMin)

			type peerCount struct {
				key   string
				count int64
			}
			top := make([]peerCount, 0, len(snap))
			for key, count := range snap {
				if count >= threshold {
					top = append(top, peerCount{key, count})
				}
			}
			sort.Slice(top, func(i, j int) bool { return top[i].count > top[j].count })
			// cap output so a huge deployment cannot flood the log
			if len(top) > 100 {
				top = top[:100]
			}
			for _, pc := range top {
				log.Debugf("peer [%s] %.1f msg/min", pc.key, float64(pc.count)/intervalMin)
			}
		}
	}
}

View File

@@ -0,0 +1,56 @@
package server
import (
"sync"
"testing"
)
// TestSendRateTracker_Increment verifies per-key counting across multiple keys.
func TestSendRateTracker_Increment(t *testing.T) {
	tr := newSendRateTracker()
	for _, key := range []string{"peer-a", "peer-a", "peer-b"} {
		tr.increment(key)
	}
	counts := tr.resetAndSnapshot()
	if counts["peer-a"] != 2 {
		t.Errorf("expected peer-a count 2, got %d", counts["peer-a"])
	}
	if counts["peer-b"] != 1 {
		t.Errorf("expected peer-b count 1, got %d", counts["peer-b"])
	}
}
// TestSendRateTracker_ResetAndSnapshot_Resets verifies the snapshot drains the counts.
func TestSendRateTracker_ResetAndSnapshot_Resets(t *testing.T) {
	tr := newSendRateTracker()
	tr.increment("peer-a")

	first := tr.resetAndSnapshot()
	if first["peer-a"] != 1 {
		t.Fatalf("expected 1, got %d", first["peer-a"])
	}

	// a second snapshot must start from a clean slate
	second := tr.resetAndSnapshot()
	if len(second) != 0 {
		t.Errorf("expected empty snapshot after reset, got %v", second)
	}
}
// TestSendRateTracker_ConcurrentIncrement verifies increment is safe under
// concurrent callers (run with -race to catch data races).
func TestSendRateTracker_ConcurrentIncrement(t *testing.T) {
	tr := newSendRateTracker()
	const workers = 100

	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			tr.increment("peer-x")
		}()
	}
	wg.Wait()

	if got := tr.resetAndSnapshot()["peer-x"]; got != 100 {
		t.Errorf("expected 100, got %d", got)
	}
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -17,6 +18,7 @@ import (
"github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
@@ -59,10 +61,12 @@ type Server struct {
successHeader metadata.MD
sendTimeout time.Duration
sendTracker *sendRateTracker
}
// NewServer creates a new Signal server
func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) {
func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverrider.Overrider, metricsPrefix ...string) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
@@ -80,14 +84,36 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
sTimeout = parsed
}
tracker := newSendRateTracker()
s := &Server{
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
sendTimeout: sTimeout,
sendTracker: tracker,
}
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateLogInterval", func(value string) error {
parsed, err := time.ParseDuration(value)
if err != nil || parsed <= 0 {
return fmt.Errorf("invalid send rate log interval %q: %w", value, err)
}
tracker.setInterval(parsed)
return nil
})
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateTopPercent", func(value string) error {
parsed, err := strconv.ParseFloat(value, 64)
if err != nil || parsed <= 0 || parsed > 1 {
return fmt.Errorf("invalid send rate top percent %q: %w", value, err)
}
tracker.setTopPercent(parsed)
return nil
})
go tracker.logSendRates(ctx)
return s, nil
}
@@ -95,6 +121,8 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
s.sendTracker.increment(msg.Key)
if _, found := s.registry.Get(msg.RemoteKey); found {
s.forwardMessageToPeer(ctx, msg)
return &proto.EncryptedMessage{}, nil