Compare commits

..

1 Commit

Author SHA1 Message Date
Ashley Mensah 6b8e40f78d initial commit - workflow yaml, prompts and schemas 2026-04-23 18:48:38 +02:00
19 changed files with 492 additions and 805 deletions

View File

@@ -0,0 +1,26 @@
You are a GitHub issue resolution classifier.
Your job is to decide whether an open GitHub issue is:
- AUTO_CLOSE
- MANUAL_REVIEW
- KEEP_OPEN
Rules:
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
- a merged linked PR that clearly resolves the issue, or
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, a duplicate, or superseded
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
5. Do not invent evidence.
6. Output valid JSON only.
Maintainer-authoritative roles:
- MEMBER
- OWNER
- COLLABORATOR
Important:
- Later comments outweigh earlier ones.
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.

View File

@@ -0,0 +1,78 @@
{
"type": "object",
"additionalProperties": false,
"required": [
"decision",
"reason_code",
"confidence",
"hard_signals",
"contradictions",
"summary",
"close_comment",
"manual_review_note"
],
"properties": {
"decision": {
"type": "string",
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
},
"reason_code": {
"type": "string",
"enum": [
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed",
"likely_fixed_but_unconfirmed",
"still_open",
"unclear"
]
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"hard_signals": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"merged_pr",
"maintainer_comment",
"duplicate_reference",
"superseded_reference"
]
},
"url": { "type": "string" }
}
}
},
"contradictions": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"reporter_still_broken",
"later_unresolved_comment",
"ambiguous_pr_link",
"other"
]
},
"url": { "type": "string" }
}
}
},
"summary": { "type": "string" },
"close_comment": { "type": "string" },
"manual_review_note": { "type": "string" }
}
}
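A minimal validation sketch for this schema, assuming it is saved as prompts/resolution-schema.json and that ajv is added as a dependency (neither the path nor the dependency appears in this commit). The sample object is purely illustrative of the output the classifier prompt asks the model to produce:

import fs from "node:fs/promises";
import Ajv from "ajv";

const schema = JSON.parse(await fs.readFile("prompts/resolution-schema.json", "utf8"));
const ajv = new Ajv({ allErrors: true });
const validateResolution = ajv.compile(schema);

// Illustrative model output: a merged linked PR exists but no maintainer confirmed the fix.
const sample = {
  decision: "MANUAL_REVIEW",
  reason_code: "likely_fixed_but_unconfirmed",
  confidence: 0.7,
  hard_signals: [{ type: "merged_pr", url: "https://github.com/example/repo/pull/1" }],
  contradictions: [],
  summary: "A linked PR was merged, but no maintainer has confirmed the fix.",
  close_comment: "",
  manual_review_note: "Please verify that the merged PR actually resolves this issue."
};

if (!validateResolution(sample)) {
  // Treat malformed model output as KEEP_OPEN rather than trusting it.
  console.error(validateResolution.errors);
}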

View File

@@ -0,0 +1,152 @@
import fs from "node:fs/promises";
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
const headers = {
Authorization: `Bearer ${process.env.GH_TOKEN}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
async function rest(url, method = "GET", body) {
const res = await fetch(url, {
method,
headers,
body: body ? JSON.stringify(body) : undefined
});
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
return res.status === 204 ? null : res.json();
}
async function graphql(query, variables) {
const res = await fetch("https://api.github.com/graphql", {
method: "POST",
headers,
body: JSON.stringify({ query, variables })
});
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
const json = await res.json();
if (json.errors) throw new Error(JSON.stringify(json.errors));
return json.data;
}
async function addLabel(owner, repo, issueNumber, labels) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
"POST",
{ labels }
);
}
async function addComment(owner, repo, issueNumber, body) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
"POST",
{ body }
);
}
async function closeIssue(owner, repo, issueNumber) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
"PATCH",
{ state: "closed", state_reason: "completed" }
);
}
async function getIssueNodeId(owner, repo, issueNumber) {
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
return issue.node_id;
}
async function addToProject(issueNodeId) {
const mutation = `
mutation($projectId: ID!, $contentId: ID!) {
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
item { id }
}
}
`;
const data = await graphql(mutation, {
projectId: process.env.PROJECT_ID,
contentId: issueNodeId
});
return data.addProjectV2ItemById.item.id;
}
async function setTextField(itemId, fieldId, value) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { text: $value }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
value
});
}
for (const d of decisions) {
const [owner, repo] = d.repository.split("/");
if (d.final_decision === "AUTO_CLOSE") {
if (dryRun) continue;
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
await addComment(
owner,
repo,
d.issue_number,
d.model.close_comment ||
"This appears resolved based on linked evidence, so were closing it automatically. Reply if this still reproduces and well reopen."
);
await closeIssue(owner, repo, d.issue_number);
}
if (d.final_decision === "MANUAL_REVIEW") {
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
const itemId = await addToProject(issueNodeId);
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
}
if (process.env.PROJECT_REASON_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
}
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
}
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
if (linked) {
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
}
}
if (process.env.PROJECT_REPO_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
}
await addComment(
owner,
repo,
d.issue_number,
d.model.manual_review_note ||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
);
}
}

View File

@@ -0,0 +1,125 @@
import fs from "node:fs/promises";
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
function isMaintainerRole(role) {
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
}
function preScore(candidate) {
let score = 0;
const hardSignals = [];
const contradictions = [];
for (const t of candidate.timeline) {
const sourceIssue = t.source?.issue;
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
hardSignals.push({
type: "merged_pr",
url: sourceIssue.html_url
});
score += 40; // provisional until PR merged state is verified
}
if (["referenced", "connected"].includes(t.event)) {
score += 10;
}
}
for (const c of candidate.comments) {
const body = c.body.toLowerCase();
if (
isMaintainerRole(c.author_association) &&
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
) {
score += 25;
hardSignals.push({
type: "maintainer_comment",
url: c.html_url
});
}
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
score -= 50;
contradictions.push({
type: "later_unresolved_comment",
url: c.html_url
});
}
}
return { score, hardSignals, contradictions };
}
async function callGitHubModel(issuePacket) {
// Replace this stub with the GitHub Models inference call used by your org.
// The workflow already has models: read permission.
return {
decision: "MANUAL_REVIEW",
reason_code: "likely_fixed_but_unconfirmed",
confidence: 0.74,
hard_signals: [],
contradictions: [],
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
close_comment: "This appears resolved, so were closing it automatically. Reply if this is still reproducible.",
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
};
}
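// One possible shape for the real call, sketched against the GitHub Models REST
// endpoint; the endpoint URL, model id, and JSON response_format support are
// assumptions to verify for your org, not part of this commit.
async function callGitHubModelSketch(issuePacket, systemPrompt) {
  const res = await fetch("https://models.github.ai/inference/chat/completions", {
    method: "POST",
    headers: {
      Authorization: `Bearer ${process.env.GH_TOKEN}`,
      "Content-Type": "application/json"
    },
    body: JSON.stringify({
      model: "openai/gpt-4o-mini", // assumed model id
      response_format: { type: "json_object" },
      messages: [
        { role: "system", content: systemPrompt },
        { role: "user", content: JSON.stringify(issuePacket) }
      ]
    })
  });
  if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
  const completion = await res.json();
  // The prompt instructs the model to output JSON only, so parse the message body.
  return JSON.parse(completion.choices[0].message.content);
}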
function enforcePolicy(modelOut, pre) {
const approvedReasons = new Set([
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed"
]);
const hasHardSignal =
(modelOut.hard_signals || []).some(s =>
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
) || pre.hardSignals.length > 0;
const hasContradiction =
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
if (
modelOut.decision === "AUTO_CLOSE" &&
modelOut.confidence >= 0.97 &&
approvedReasons.has(modelOut.reason_code) &&
hasHardSignal &&
!hasContradiction
) {
return "AUTO_CLOSE";
}
if (
modelOut.decision === "MANUAL_REVIEW" ||
modelOut.confidence >= 0.60 ||
pre.score >= 25
) {
return "MANUAL_REVIEW";
}
return "KEEP_OPEN";
}
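// Worked example: a model AUTO_CLOSE at confidence 0.9 with a merged-PR signal and no
// contradictions still misses the 0.97 bar, so enforcePolicy downgrades it:
//   enforcePolicy(
//     { decision: "AUTO_CLOSE", reason_code: "resolved_by_merged_pr", confidence: 0.9,
//       hard_signals: [{ type: "merged_pr", url: "..." }], contradictions: [] },
//     { score: 40, hardSignals: [], contradictions: [] }
//   ) // -> "MANUAL_REVIEW", because confidence >= 0.60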
const decisions = [];
for (const candidate of candidates) {
const pre = preScore(candidate);
const modelOut = await callGitHubModel(candidate);
const finalDecision = enforcePolicy(modelOut, pre);
decisions.push({
repository: candidate.repository,
issue_number: candidate.issue.number,
issue_url: candidate.issue.html_url,
title: candidate.issue.title,
pre_score: pre.score,
final_decision: finalDecision,
model: modelOut
});
}
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));

View File

@@ -0,0 +1,50 @@
name: issue-resolution-triage
on:
  workflow_dispatch:
    inputs:
      dry_run:
        description: "If true, do not close issues"
        required: false
        default: "true"
      max_issues:
        description: "How many issues to process"
        required: false
        default: "100"
  schedule:
    - cron: "17 2 * * *"
permissions:
  contents: read
  issues: write
  pull-requests: read
  models: read
jobs:
  triage:
    runs-on: ubuntu-latest
    env:
      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      DRY_RUN: ${{ inputs.dry_run || 'true' }}
      MAX_ISSUES: ${{ inputs.max_issues || '100' }}
      REPO: ${{ github.repository }}
      PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
      PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
      PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
      PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
      PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
      PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
      PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
      PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version: "20"
      - run: npm ci
      - run: node scripts/fetch-candidates.mjs
      - run: node scripts/classify-candidates.mjs
      - run: node scripts/apply-decisions.mjs
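A usage note on the triggers: the scheduled 02:17 UTC run has an empty inputs context, so DRY_RUN falls back to 'true' and no issues are closed automatically; closing only happens when the workflow is dispatched manually with dry_run set to false, for example via the GitHub CLI with gh workflow run issue-resolution-triage -f dry_run=false -f max_issues=25.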

View File

@@ -19,7 +19,6 @@ import (
"google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
@@ -31,11 +30,9 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/util/crypt"
)
@@ -75,23 +72,6 @@ func (s *BaseServer) CacheStore() cachestore.StoreInterface {
})
}
// SettingOverrider returns a shared setting overrider backed by Redis.
// Returns a no-op overrider if no Redis address is configured.
func (s *BaseServer) SettingOverrider() *settingoverrider.Overrider {
return Create(s, func() *settingoverrider.Overrider {
redisAddr := nbcache.GetAddrFromEnv()
if redisAddr == "" {
return settingoverrider.NewNoop()
}
o, err := settingoverrider.New(context.Background(), redisAddr)
if err != nil {
log.Fatalf("failed to create setting overrider: %v", err)
}
return o
})
}
func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
@@ -129,7 +109,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -137,15 +117,6 @@ func (s *BaseServer) APIHandler() http.Handler {
})
}
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()
limiter := middleware.NewAPIRateLimiter(cfg)
limiter.SetEnabled(enabled)
return limiter
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
"github.com/netbirdio/netbird/version"
@@ -124,15 +123,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.PeersManager()
s.GeoLocationManager()
s.SettingOverrider().Poll(settingoverrider.DefaultInterval, "managementLogLevel", func(value string) error {
level, err := log.ParseLevel(value)
if err != nil {
return fmt.Errorf("parsing log level %q: %w", value, err)
}
log.SetLevel(level)
return nil
})
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err)
@@ -245,7 +235,6 @@ func (s *BaseServer) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.SettingOverrider().Close()
s.IntegratedValidator().Stop(ctx)
if s.GeoLocationManager() != nil {
_ = s.GeoLocationManager().Stop()

View File

@@ -5,6 +5,9 @@ import (
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
@@ -63,11 +66,14 @@ import (
)
const (
apiPrefix = "/api"
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -88,10 +94,34 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
}
authMiddleware := middleware.NewAuthMiddleware(
@@ -99,7 +129,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimiter,
rateLimitingConfig,
appMetrics.GetMeter(),
)

View File

@@ -43,9 +43,14 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiter *APIRateLimiter,
rateLimiterConfig *RateLimiterConfig,
meter metric.Meter,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker
if meter != nil {
var err error
@@ -176,8 +181,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
}
ctx := r.Context()

View File

@@ -196,8 +196,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -209,7 +207,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
disabledLimiter,
nil,
nil,
)
@@ -268,7 +266,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -320,7 +318,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -363,7 +361,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -407,7 +405,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -471,7 +469,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -530,7 +528,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -585,7 +583,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
NewAPIRateLimiter(rateLimitConfig),
rateLimitConfig,
nil,
)
@@ -672,8 +670,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -685,7 +681,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
disabledLimiter,
nil,
nil,
)

View File

@@ -4,27 +4,14 @@ import (
"context"
"net"
"net/http"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
const (
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
defaultAPIRPM = 6
defaultAPIBurst = 500
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
@@ -47,43 +34,6 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
}
}
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
rpm := defaultAPIRPM
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
} else {
rpm = value
}
}
if rpm <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
rpm = defaultAPIRPM
}
burst := defaultAPIBurst
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
} else {
burst = value
}
}
if burst <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
burst = defaultAPIBurst
}
return &RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}, os.Getenv(RateLimitingEnabledEnv) == "true"
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
@@ -96,7 +46,6 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
enabled atomic.Bool
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -110,53 +59,14 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
rl.enabled.Store(true)
go rl.cleanupLoop()
return rl
}
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
rl.enabled.Store(enabled)
}
func (rl *APIRateLimiter) Enabled() bool {
return rl.enabled.Load()
}
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
if config == nil {
return
}
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
return
}
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
newBurst := config.Burst
rl.mu.Lock()
rl.config.RequestsPerMinute = config.RequestsPerMinute
rl.config.Burst = newBurst
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
for _, entry := range rl.limiters {
snapshot = append(snapshot, entry.limiter)
}
rl.mu.Unlock()
for _, l := range snapshot {
l.SetLimit(newRPS)
l.SetBurst(newBurst)
}
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
if !rl.enabled.Load() {
return true
}
limiter := rl.getLimiter(key)
return limiter.Allow()
}
@@ -164,9 +74,6 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
if !rl.enabled.Load() {
return nil
}
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
@@ -246,10 +153,6 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,10 +1,8 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@@ -158,172 +156,3 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -1,120 +0,0 @@
package settingoverrider
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus"
)
const (
DefaultInterval = 5 * time.Minute
)
// ApplyFunc is called with the raw Redis string value whenever it changes.
// The function is responsible for parsing and applying the value.
// Return an error to log a warning without stopping the polling loop.
type ApplyFunc func(value string) error
// Overrider holds a shared Redis connection and allows registering
// individual settings that are polled independently.
type Overrider struct {
client *redis.Client
cancel context.CancelFunc
ctx context.Context
noop bool
}
// New creates an Overrider by connecting to Redis at the given address.
// The address should follow the Redis URL format (e.g. "redis://localhost:6379").
func New(ctx context.Context, redisAddr string) (*Overrider, error) {
if redisAddr == "" {
return nil, fmt.Errorf("redis address is empty")
}
options, err := redis.ParseURL(redisAddr)
if err != nil {
return nil, fmt.Errorf("parsing redis address: %w", err)
}
client := redis.NewClient(options)
pingCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
if _, err := client.Ping(pingCtx).Result(); err != nil {
_ = client.Close()
return nil, fmt.Errorf("connecting to redis: %w", err)
}
oCtx, oCancel := context.WithCancel(ctx)
return &Overrider{client: client, cancel: oCancel, ctx: oCtx}, nil
}
// NewNoop returns an Overrider that does nothing.
// Poll calls are silently ignored and Close is a no-op.
func NewNoop() *Overrider {
return &Overrider{noop: true}
}
// Close stops all polling goroutines and closes the underlying Redis client.
func (o *Overrider) Close() error {
if o.noop {
return nil
}
o.cancel()
return o.client.Close()
}
// Poll starts a background goroutine that polls a single Redis key at the given interval
// and calls apply whenever the value changes. The goroutine stops when the Overrider is closed.
func (o *Overrider) Poll(interval time.Duration, redisKey string, apply ApplyFunc) {
if o.noop {
return
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
var lastSeen *string
for {
select {
case <-o.ctx.Done():
log.WithContext(o.ctx).Infof("Stopping settings overrider for key %q", redisKey)
return
case <-ticker.C:
getCtx, cancel := context.WithTimeout(o.ctx, 5*time.Second)
val, err := o.client.Get(getCtx, redisKey).Result()
cancel()
if errors.Is(err, redis.Nil) || val == "" {
continue
}
if err != nil {
if o.ctx.Err() != nil {
return
}
log.WithContext(o.ctx).Errorf("Unable to get setting %q from Redis: %v", redisKey, err)
continue
}
if lastSeen != nil && *lastSeen == val {
continue
}
if err := apply(val); err != nil {
log.WithContext(o.ctx).Warnf("Failed to apply setting %q with value %q: %v", redisKey, val, err)
continue
}
lastSeen = &val
}
}
}()
}

View File

@@ -1,111 +0,0 @@
package settingoverrider
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
"github.com/testcontainers/testcontainers-go/wait"
)
func TestPoll_AppliesSettingFromRedis(t *testing.T) {
o, client := setupOverrider(t)
key := "test-setting-key"
require.NoError(t, client.Set(context.Background(), key, "hello", 0).Err())
var applied atomic.Value
o.Poll(100*time.Millisecond, key, func(value string) error {
applied.Store(value)
return nil
})
assert.Eventually(t, func() bool {
v := applied.Load()
return v != nil && v.(string) == "hello"
}, 5*time.Second, 50*time.Millisecond)
}
func TestPoll_IndependentSettings(t *testing.T) {
o, client := setupOverrider(t)
require.NoError(t, client.Set(context.Background(), "key-a", "val-a", 0).Err())
require.NoError(t, client.Set(context.Background(), "key-b", "val-b", 0).Err())
var gotA, gotB atomic.Value
o.Poll(100*time.Millisecond, "key-a", func(v string) error { gotA.Store(v); return nil })
o.Poll(100*time.Millisecond, "key-b", func(v string) error { gotB.Store(v); return nil })
assert.Eventually(t, func() bool {
a, b := gotA.Load(), gotB.Load()
return a != nil && a.(string) == "val-a" && b != nil && b.(string) == "val-b"
}, 5*time.Second, 50*time.Millisecond)
}
func TestPoll_SkipsDuplicateValues(t *testing.T) {
o, client := setupOverrider(t)
key := "test-dedup"
require.NoError(t, client.Set(context.Background(), key, "same", 0).Err())
var count atomic.Int32
o.Poll(100*time.Millisecond, key, func(string) error {
count.Add(1)
return nil
})
// wait for a few ticks
time.Sleep(600 * time.Millisecond)
assert.Equal(t, int32(1), count.Load(), "Apply should be called only once for unchanged value")
}
func setupOverrider(t *testing.T) (*Overrider, *redis.Client) {
t.Helper()
ctx := context.Background()
redisContainer, err := testcontainersredis.RunContainer(ctx,
testcontainers.WithImage("redis:7"),
testcontainers.WithWaitStrategy(
wait.ForListeningPort("6379/tcp"),
),
)
require.NoError(t, err, "Failed to create redis test container")
t.Cleanup(func() {
if err := redisContainer.Terminate(ctx); err != nil {
t.Logf("failed to terminate redis container: %s", err)
}
})
redisURL, err := redisContainer.ConnectionString(ctx)
require.NoError(t, err)
o, err := New(ctx, redisURL)
require.NoError(t, err)
t.Cleanup(func() {
if err := o.Close(); err != nil {
t.Logf("failed to close overrider: %s", err)
}
})
// separate client for test setup (setting keys)
options, err := redis.ParseURL(redisURL)
require.NoError(t, err)
client := redis.NewClient(options)
t.Cleanup(func() {
if err := client.Close(); err != nil {
t.Logf("failed to close redis client: %s", err)
}
})
return o, client
}

View File

@@ -18,9 +18,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -116,24 +114,7 @@ var (
}
}()
overrider := settingoverrider.NewNoop()
if redisAddr := cache.GetAddrFromEnv(); redisAddr != "" {
overrider, err = settingoverrider.New(cmd.Context(), redisAddr)
if err != nil {
return fmt.Errorf("failed to create setting overrider: %w", err)
}
defer func() { _ = overrider.Close() }()
}
overrider.Poll(settingoverrider.DefaultInterval, "signalLogLevel", func(value string) error {
level, err := log.ParseLevel(value)
if err != nil {
return fmt.Errorf("parsing log level %q: %w", value, err)
}
log.SetLevel(level)
return nil
})
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, overrider)
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
if err != nil {
return fmt.Errorf("creating signal server: %v", err)
}

View File

@@ -1,135 +0,0 @@
package server
import (
"context"
"math"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultSendRateLogInterval = 5 * time.Minute
defaultSendRateTopPercent = 0.95
envSendRateLogInterval = "NB_SIGNAL_SEND_RATE_LOG_INTERVAL"
envSendRateTopPercent = "NB_SIGNAL_SEND_RATE_LOG_TOP_PERCENT"
)
// sendRateTracker tracks per-key message counts and logs the busiest peers periodically.
type sendRateTracker struct {
mu sync.Mutex
counts map[string]int64
// atomic so they can be updated by the setting overrider without locking
intervalNs atomic.Int64
// topPercent stored as float64 bits for atomic access
topPercentBits atomic.Uint64
}
func newSendRateTracker() *sendRateTracker {
interval := defaultSendRateLogInterval
if v := os.Getenv(envSendRateLogInterval); v != "" {
if parsed, err := time.ParseDuration(v); err == nil && parsed > 0 {
interval = parsed
}
}
topPercent := defaultSendRateTopPercent
if v := os.Getenv(envSendRateTopPercent); v != "" {
if parsed, err := strconv.ParseFloat(v, 64); err == nil && parsed > 0 && parsed <= 1 {
topPercent = parsed
}
}
log.Debugf("send rate tracker: interval=%s, top_percent=%.2f", interval, topPercent)
t := &sendRateTracker{
counts: make(map[string]int64),
}
t.intervalNs.Store(int64(interval))
t.topPercentBits.Store(math.Float64bits(topPercent))
return t
}
func (t *sendRateTracker) getInterval() time.Duration {
return time.Duration(t.intervalNs.Load())
}
func (t *sendRateTracker) setInterval(d time.Duration) {
t.intervalNs.Store(int64(d))
}
func (t *sendRateTracker) getTopPercent() float64 {
return math.Float64frombits(t.topPercentBits.Load())
}
func (t *sendRateTracker) setTopPercent(p float64) {
t.topPercentBits.Store(math.Float64bits(p))
}
func (t *sendRateTracker) increment(key string) {
t.mu.Lock()
t.counts[key]++
t.mu.Unlock()
}
// resetAndSnapshot atomically returns current counts and resets the tracker.
func (t *sendRateTracker) resetAndSnapshot() map[string]int64 {
t.mu.Lock()
snap := t.counts
t.counts = make(map[string]int64, len(snap))
t.mu.Unlock()
return snap
}
// logSendRates periodically logs the peers whose send rate falls within the configured top fraction of the busiest peer's rate.
func (t *sendRateTracker) logSendRates(ctx context.Context) {
currentInterval := t.getInterval()
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if newInterval := t.getInterval(); newInterval != currentInterval {
currentInterval = newInterval
ticker.Reset(currentInterval)
}
snap := t.resetAndSnapshot()
if len(snap) == 0 {
continue
}
var maxCount int64
for _, count := range snap {
if count > maxCount {
maxCount = count
}
}
topPercent := t.getTopPercent()
threshold := int64(float64(maxCount) * topPercent)
intervalMin := currentInterval.Minutes()
log.Debugf("send rate stats: %d unique peers in last %.0fs, max rate %.1f msg/min",
len(snap), currentInterval.Seconds(), float64(maxCount)/intervalMin)
logged := 0
for key, count := range snap {
if count >= threshold {
log.Debugf("peer [%s] %.1f msg/min", key, float64(count)/intervalMin)
logged++
if logged >= 100 {
break
}
}
}
}
}
}

View File

@@ -1,56 +0,0 @@
package server
import (
"sync"
"testing"
)
func TestSendRateTracker_Increment(t *testing.T) {
tracker := newSendRateTracker()
tracker.increment("peer-a")
tracker.increment("peer-a")
tracker.increment("peer-b")
snap := tracker.resetAndSnapshot()
if snap["peer-a"] != 2 {
t.Errorf("expected peer-a count 2, got %d", snap["peer-a"])
}
if snap["peer-b"] != 1 {
t.Errorf("expected peer-b count 1, got %d", snap["peer-b"])
}
}
func TestSendRateTracker_ResetAndSnapshot_Resets(t *testing.T) {
tracker := newSendRateTracker()
tracker.increment("peer-a")
snap1 := tracker.resetAndSnapshot()
if snap1["peer-a"] != 1 {
t.Fatalf("expected 1, got %d", snap1["peer-a"])
}
snap2 := tracker.resetAndSnapshot()
if len(snap2) != 0 {
t.Errorf("expected empty snapshot after reset, got %v", snap2)
}
}
func TestSendRateTracker_ConcurrentIncrement(t *testing.T) {
tracker := newSendRateTracker()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
tracker.increment("peer-x")
}()
}
wg.Wait()
snap := tracker.resetAndSnapshot()
if snap["peer-x"] != 100 {
t.Errorf("expected 100, got %d", snap["peer-x"])
}
}

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -18,7 +17,6 @@ import (
"github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
@@ -61,12 +59,10 @@ type Server struct {
successHeader metadata.MD
sendTimeout time.Duration
sendTracker *sendRateTracker
}
// NewServer creates a new Signal server
func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverrider.Overrider, metricsPrefix ...string) (*Server, error) {
func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
@@ -84,36 +80,14 @@ func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverri
sTimeout = parsed
}
tracker := newSendRateTracker()
s := &Server{
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
sendTimeout: sTimeout,
sendTracker: tracker,
}
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateLogInterval", func(value string) error {
parsed, err := time.ParseDuration(value)
if err != nil || parsed <= 0 {
return fmt.Errorf("invalid send rate log interval %q: %w", value, err)
}
tracker.setInterval(parsed)
return nil
})
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateTopPercent", func(value string) error {
parsed, err := strconv.ParseFloat(value, 64)
if err != nil || parsed <= 0 || parsed > 1 {
return fmt.Errorf("invalid send rate top percent %q: %w", value, err)
}
tracker.setTopPercent(parsed)
return nil
})
go tracker.logSendRates(ctx)
return s, nil
}
@@ -121,8 +95,6 @@ func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverri
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
s.sendTracker.increment(msg.Key)
if _, found := s.registry.Get(msg.RemoteKey); found {
s.forwardMessageToPeer(ctx, msg)
return &proto.EncryptedMessage{}, nil