Compare commits

..

1 Commits

Author SHA1 Message Date
Viktor Liu
e81ce81494 Bound embed client WireGuard per-Device memory 2026-04-22 13:04:40 +02:00
27 changed files with 524 additions and 1658 deletions

View File

@@ -1,26 +0,0 @@
You are a GitHub issue resolution classifier.
Your job is to decide whether an open GitHub issue is:
- AUTO_CLOSE
- MANUAL_REVIEW
- KEEP_OPEN
Rules:
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
- a merged linked PR that clearly resolves the issue, or
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
5. Do not invent evidence.
6. Output valid JSON only.
Maintainer-authoritative roles:
- MEMBER
- OWNER
- COLLABORATOR
Important:
- Later comments outweigh earlier ones.
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.

View File

@@ -1,78 +0,0 @@
{
"type": "object",
"additionalProperties": false,
"required": [
"decision",
"reason_code",
"confidence",
"hard_signals",
"contradictions",
"summary",
"close_comment",
"manual_review_note"
],
"properties": {
"decision": {
"type": "string",
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
},
"reason_code": {
"type": "string",
"enum": [
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed",
"likely_fixed_but_unconfirmed",
"still_open",
"unclear"
]
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"hard_signals": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"merged_pr",
"maintainer_comment",
"duplicate_reference",
"superseded_reference"
]
},
"url": { "type": "string" }
}
}
},
"contradictions": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"reporter_still_broken",
"later_unresolved_comment",
"ambiguous_pr_link",
"other"
]
},
"url": { "type": "string" }
}
}
},
"summary": { "type": "string" },
"close_comment": { "type": "string" },
"manual_review_note": { "type": "string" }
}
}

View File

@@ -1,152 +0,0 @@
import fs from "node:fs/promises";
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
const headers = {
Authorization: `Bearer ${process.env.GH_TOKEN}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
async function rest(url, method = "GET", body) {
const res = await fetch(url, {
method,
headers,
body: body ? JSON.stringify(body) : undefined
});
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
return res.status === 204 ? null : res.json();
}
async function graphql(query, variables) {
const res = await fetch("https://api.github.com/graphql", {
method: "POST",
headers,
body: JSON.stringify({ query, variables })
});
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
const json = await res.json();
if (json.errors) throw new Error(JSON.stringify(json.errors));
return json.data;
}
async function addLabel(owner, repo, issueNumber, labels) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
"POST",
{ labels }
);
}
async function addComment(owner, repo, issueNumber, body) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
"POST",
{ body }
);
}
async function closeIssue(owner, repo, issueNumber) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
"PATCH",
{ state: "closed", state_reason: "completed" }
);
}
async function getIssueNodeId(owner, repo, issueNumber) {
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
return issue.node_id;
}
async function addToProject(issueNodeId) {
const mutation = `
mutation($projectId: ID!, $contentId: ID!) {
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
item { id }
}
}
`;
const data = await graphql(mutation, {
projectId: process.env.PROJECT_ID,
contentId: issueNodeId
});
return data.addProjectV2ItemById.item.id;
}
async function setTextField(itemId, fieldId, value) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { text: $value }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
value
});
}
for (const d of decisions) {
const [owner, repo] = d.repository.split("/");
if (d.final_decision === "AUTO_CLOSE") {
if (dryRun) continue;
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
await addComment(
owner,
repo,
d.issue_number,
d.model.close_comment ||
"This appears resolved based on linked evidence, so were closing it automatically. Reply if this still reproduces and well reopen."
);
await closeIssue(owner, repo, d.issue_number);
}
if (d.final_decision === "MANUAL_REVIEW") {
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
const itemId = await addToProject(issueNodeId);
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
}
if (process.env.PROJECT_REASON_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
}
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
}
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
if (linked) {
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
}
}
if (process.env.PROJECT_REPO_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
}
await addComment(
owner,
repo,
d.issue_number,
d.model.manual_review_note ||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
);
}
}

View File

@@ -1,125 +0,0 @@
import fs from "node:fs/promises";
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
function isMaintainerRole(role) {
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
}
function preScore(candidate) {
let score = 0;
const hardSignals = [];
const contradictions = [];
for (const t of candidate.timeline) {
const sourceIssue = t.source?.issue;
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
hardSignals.push({
type: "merged_pr",
url: sourceIssue.html_url
});
score += 40; // provisional until PR merged state is verified
}
if (["referenced", "connected"].includes(t.event)) {
score += 10;
}
}
for (const c of candidate.comments) {
const body = c.body.toLowerCase();
if (
isMaintainerRole(c.author_association) &&
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
) {
score += 25;
hardSignals.push({
type: "maintainer_comment",
url: c.html_url
});
}
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
score -= 50;
contradictions.push({
type: "later_unresolved_comment",
url: c.html_url
});
}
}
return { score, hardSignals, contradictions };
}
async function callGitHubModel(issuePacket) {
// Replace this stub with the GitHub Models inference call used by your org.
// The workflow already has models: read permission.
return {
decision: "MANUAL_REVIEW",
reason_code: "likely_fixed_but_unconfirmed",
confidence: 0.74,
hard_signals: [],
contradictions: [],
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
close_comment: "This appears resolved, so were closing it automatically. Reply if this is still reproducible.",
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
};
}
function enforcePolicy(modelOut, pre) {
const approvedReasons = new Set([
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed"
]);
const hasHardSignal =
(modelOut.hard_signals || []).some(s =>
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
) || pre.hardSignals.length > 0;
const hasContradiction =
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
if (
modelOut.decision === "AUTO_CLOSE" &&
modelOut.confidence >= 0.97 &&
approvedReasons.has(modelOut.reason_code) &&
hasHardSignal &&
!hasContradiction
) {
return "AUTO_CLOSE";
}
if (
modelOut.decision === "MANUAL_REVIEW" ||
modelOut.confidence >= 0.60 ||
pre.score >= 25
) {
return "MANUAL_REVIEW";
}
return "KEEP_OPEN";
}
const decisions = [];
for (const candidate of candidates) {
const pre = preScore(candidate);
const modelOut = await callGitHubModel(candidate);
const finalDecision = enforcePolicy(modelOut, pre);
decisions.push({
repository: candidate.repository,
issue_number: candidate.issue.number,
issue_url: candidate.issue.html_url,
title: candidate.issue.title,
pre_score: pre.score,
final_decision: finalDecision,
model: modelOut
});
}
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));

View File

@@ -1,50 +0,0 @@
name: issue-resolution-triage
on:
workflow_dispatch:
inputs:
dry_run:
description: "If true, do not close issues"
required: false
default: "true"
max_issues:
description: "How many issues to process"
required: false
default: "100"
schedule:
- cron: "17 2 * * *"
permissions:
contents: read
issues: write
pull-requests: read
models: read
jobs:
triage:
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DRY_RUN: ${{ inputs.dry_run || 'true' }}
MAX_ISSUES: ${{ inputs.max_issues || '100' }}
REPO: ${{ github.repository }}
PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: "20"
- run: npm ci
- run: node scripts/fetch-candidates.mjs
- run: node scripts/classify-candidates.mjs
- run: node scripts/apply-decisions.mjs

View File

@@ -12,6 +12,7 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -469,6 +470,55 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// WGTuning bundles runtime-adjustable WireGuard knobs exposed by the embed
// client. Nil fields are left unchanged; set a non-nil pointer to apply.
type WGTuning struct {
// PreallocatedBuffersPerPool caps each per-Device WaitPool.
// Zero means "unbounded" (no cap). Live-tunable only if the underlying
// Device was originally created with a nonzero cap.
PreallocatedBuffersPerPool *uint32
}
// SetWGTuning applies the given tuning to this client's live Device.
// Startup-only knobs (batch size) must be set via the package-level
// setters before Start.
func (c *Client) SetWGTuning(t WGTuning) error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetWGTuning(internal.WGTuning{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// SetWGDefaultPreallocatedBuffersPerPool sets the default WaitPool cap
// applied to Devices created after this call. Zero disables the cap.
// Existing Devices are unaffected; use Client.SetWGTuning for that.
func SetWGDefaultPreallocatedBuffersPerPool(n uint32) {
wgdevice.SetPreallocatedBuffersPerPool(n)
}
// WGDefaultPreallocatedBuffersPerPool returns the current default WaitPool
// cap applied to newly-created Devices.
func WGDefaultPreallocatedBuffersPerPool() uint32 {
return wgdevice.PreallocatedBuffersPerPool
}
// SetWGDefaultMaxBatchSize sets the default per-Device batch size applied
// to Devices created after this call. Zero means "use the bind+tun default"
// (NOT unlimited). Must be called before Start to take effect for a new
// Client.
func SetWGDefaultMaxBatchSize(n uint32) {
wgdevice.SetMaxBatchSizeOverride(n)
}
// WGDefaultMaxBatchSize returns the current default batch-size override.
// Zero means "no override".
func WGDefaultMaxBatchSize() uint32 {
return wgdevice.MaxBatchSizeOverride
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.

View File

@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++
}
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
// routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -3,12 +3,10 @@ package debug
import (
"context"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci")
}
testDir := t.TempDir()
addr := reserveLoopbackPort(t)
testURL := "http://" + addr
testURL := "http://localhost:8080"
t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir)
srv := server.NewServer()
go func() {
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err)
}
})
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content")
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent)
}
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -1,10 +1,7 @@
package dns
import (
"context"
"fmt"
"math"
"net"
"slices"
"strconv"
"strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 {
return
}
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order
for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) {
continue
}
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime))
}
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
}
}
}
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,15 +1,11 @@
package dns_test
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
})
}
}
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -2,83 +2,40 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
const dnsTimeout = 5 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question]*cachedRecord
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question][]dns.RR),
}
}
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
}
}
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return err
}
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
now := time.Now()
m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
// untouched on error. Caller holds m.mutex.
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
// unique in-flight key; bursty stale hits share its channel. expected is the
// cachedRecord pointer observed by the caller; the refresh only mutates the
// cache if that pointer is still the one stored, so a stale in-flight refresh
// can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
delete(m.records, qA)
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,7 +6,6 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains())
}
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
ctx: ctx,

View File

@@ -1874,6 +1874,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// WGTuning bundles runtime-adjustable WireGuard pool knobs.
// See Engine.SetWGTuning. Nil fields are ignored.
type WGTuning struct {
PreallocatedBuffersPerPool *uint32
}
// SetWGTuning applies the given tuning to this engine's live Device.
func (e *Engine) SetWGTuning(t WGTuning) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -457,18 +457,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
// which is when Receive has actually returned. Without this, the Receive goroutine
// can outlive the test and call t.Logf after teardown, panicking.
receiveDone := make(chan struct{})
t.Cleanup(func() {
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Error("Receive goroutine did not exit after Close")
}
})
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
@@ -480,7 +468,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
receivedAfterReconnect := make(chan struct{})
go func() {
defer close(receiveDone)
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 {
return nil

2
go.mod
View File

@@ -314,7 +314,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

4
go.sum
View File

@@ -459,8 +459,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0=
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=

View File

@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "Registering CrowdSec bouncer..."
local cs_retries=0
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
cs_retries=$((cs_retries + 1))
if [[ $cs_retries -ge 30 ]]; then
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr

View File

@@ -12,7 +12,6 @@ import (
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -88,14 +87,17 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
request, err := m.checkJWTFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
h.ServeHTTP(w, request)
case "token":
if err := m.checkPATFromRequest(r, authHeader); err != nil {
request, err := m.checkPATFromRequest(r, authHeader)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
@@ -104,7 +106,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
util.WriteError(r.Context(), err, w)
return
}
h.ServeHTTP(w, r)
h.ServeHTTP(w, request)
default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
@@ -113,19 +115,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
return r, fmt.Errorf("error extracting token: %w", err)
}
ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil {
return err
return r, err
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
@@ -141,7 +143,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return err
return r, err
}
if userAuth.AccountId != accountId {
@@ -151,7 +153,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return err
return r, err
}
err = m.syncUserJWTGroups(ctx, userAuth)
@@ -162,19 +164,17 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
_, err = m.getUserFromUserAuth(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return err
return r, err
}
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
return r, fmt.Errorf("error extracting token: %w", err)
}
if m.patUsageTracker != nil {
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
}
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
return r, fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.GetExpirationDate()) {
return fmt.Errorf("token expired")
return r, fmt.Errorf("token expired")
}
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
return err
return r, err
}
userAuth := auth.UserAuth{
@@ -216,9 +216,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
}
}
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
func isTerraformRequest(r *http.Request) bool {

View File

@@ -193,12 +193,20 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
}
})
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
req := r.WithContext(ctx)
h.ServeHTTP(w, req)
h.ServeHTTP(w, r.WithContext(ctx))
close(handlerDone)
ctx = req.Context()
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {
if userAuth.AccountId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
}
if userAuth.UserId != "" {
//nolint
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
}
}
if w.Status() > 399 {
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())

View File

@@ -99,6 +99,35 @@ var debugStopCmd = &cobra.Command{
SilenceUsage: true,
}
var debugWGTuneCmd = &cobra.Command{
Use: "wgtune",
Short: "Inspect and live-tune WireGuard pool settings",
}
var debugWGTuneGetCmd = &cobra.Command{
Use: "get",
Short: "Show pool cap and batch size defaults",
Args: cobra.NoArgs,
RunE: runDebugWGTuneGet,
SilenceUsage: true,
}
var debugWGTuneSetCmd = &cobra.Command{
Use: "set <pool-cap>",
Short: "Set the pool cap (new and live clients)",
Args: cobra.ExactArgs(1),
RunE: runDebugWGTuneSet,
SilenceUsage: true,
}
var debugRuntimeCmd = &cobra.Command{
Use: "runtime",
Short: "Show runtime stats (heap, goroutines, RSS)",
Args: cobra.NoArgs,
RunE: runDebugRuntime,
SilenceUsage: true,
}
func init() {
debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address")
debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format")
@@ -119,6 +148,10 @@ func init() {
debugCmd.AddCommand(debugLogCmd)
debugCmd.AddCommand(debugStartCmd)
debugCmd.AddCommand(debugStopCmd)
debugWGTuneCmd.AddCommand(debugWGTuneGetCmd)
debugWGTuneCmd.AddCommand(debugWGTuneSetCmd)
debugCmd.AddCommand(debugWGTuneCmd)
debugCmd.AddCommand(debugRuntimeCmd)
rootCmd.AddCommand(debugCmd)
}
@@ -171,3 +204,19 @@ func runDebugStart(cmd *cobra.Command, args []string) error {
func runDebugStop(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
}
func runDebugWGTuneGet(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).WGTuneGet(cmd.Context())
}
func runDebugWGTuneSet(cmd *cobra.Command, args []string) error {
n, err := strconv.ParseUint(args[0], 10, 32)
if err != nil {
return fmt.Errorf("invalid value %q: %w", args[0], err)
}
return getDebugClient(cmd).WGTuneSet(cmd.Context(), uint32(n))
}
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).Runtime(cmd.Context())
}

View File

@@ -15,11 +15,22 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy"
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/util"
)
const (
// envWGPreallocatedBuffers caps the per-Device WireGuard buffer pool
// size. Zero (unset) keeps the uncapped upstream default.
envWGPreallocatedBuffers = "NB_WG_PREALLOCATED_BUFFERS"
// envWGMaxBatchSize overrides the per-Device WireGuard batch size,
// which controls how many buffers each receive/TUN worker eagerly
// allocates. Zero (unset) keeps the bind+tun default.
envWGMaxBatchSize = "NB_WG_MAX_BATCH_SIZE"
)
const DefaultManagementURL = "https://api.netbird.io:443"
// envProxyToken is the environment variable name for the proxy access token.
@@ -145,6 +156,42 @@ func runServer(cmd *cobra.Command, args []string) error {
logger.Infof("configured log level: %s", level)
var wgPool, wgBatch uint64
if raw := os.Getenv(envWGPreallocatedBuffers); raw != "" {
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return fmt.Errorf("invalid %s %q: %w", envWGPreallocatedBuffers, raw, err)
}
wgPool = n
embed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
logger.Infof("wireguard preallocated buffers per pool: %d", n)
}
if raw := os.Getenv(envWGMaxBatchSize); raw != "" {
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return fmt.Errorf("invalid %s %q: %w", envWGMaxBatchSize, raw, err)
}
wgBatch = n
embed.SetWGDefaultMaxBatchSize(uint32(n))
logger.Infof("wireguard max batch size override: %d", n)
}
if wgPool > 0 {
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
// the lifetime of the Device. A pool cap below that floor blocks
// the receive pipeline at startup.
batch := wgBatch
if batch == 0 {
batch = 128
}
const recvGoroutines = 4
floor := batch * recvGoroutines
if wgPool < floor {
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
envWGPreallocatedBuffers, wgPool, floor, batch)
}
}
switch forwardedProto {
case "auto", "http", "https":
default:

View File

@@ -272,6 +272,74 @@ func (c *Client) printLogLevelResult(data map[string]any) {
}
}
// WGTuneGet fetches the current WireGuard pool cap.
func (c *Client) WGTuneGet(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/wgtune", c.printWGTuneGet)
}
// WGTuneSet updates the WireGuard pool cap on the global default and all live clients.
func (c *Client) WGTuneSet(ctx context.Context, value uint32) error {
path := fmt.Sprintf("/debug/wgtune?value=%d", value)
return c.fetchAndPrint(ctx, path, c.printWGTuneSet)
}
func (c *Client) printWGTuneGet(data map[string]any) {
def, _ := data["default"].(float64)
batch, _ := data["batch_size"].(float64)
_, _ = fmt.Fprintf(c.out, "Default: %d\n", uint32(def))
_, _ = fmt.Fprintf(c.out, "Batch size: %d (0 = unset)\n", uint32(batch))
}
func (c *Client) printWGTuneSet(data map[string]any) {
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
c.printError(data)
return
}
def, _ := data["default"].(float64)
applied, _ := data["applied"].(float64)
_, _ = fmt.Fprintf(c.out, "Default set to: %d\n", uint32(def))
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
_, _ = fmt.Fprintln(c.out, "Failed:")
for k, v := range failed {
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
}
}
}
// Runtime fetches runtime stats (heap, goroutines, RSS).
func (c *Client) Runtime(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
}
func (c *Client) printRuntime(data map[string]any) {
i := func(k string) uint64 {
v, _ := data[k].(float64)
return uint64(v)
}
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
_, _ = fmt.Fprintln(c.out, "Heap:")
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
if _, ok := data["vm_rss"]; ok {
_, _ = fmt.Fprintln(c.out, "Process:")
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
}
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
}
// StartClient starts a specific client.
func (c *Client) StartClient(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"

View File

@@ -10,6 +10,8 @@ import (
"html/template"
"maps"
"net/http"
"os"
"runtime"
"slices"
"strconv"
"strings"
@@ -58,6 +60,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
type clientProvider interface {
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
ListClientsForStartup() map[types.AccountID]*nbembed.Client
}
// healthChecker provides health probe state.
@@ -139,6 +142,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handleListClients(w, r, wantJSON)
case "/debug/health":
h.handleHealth(w, r, wantJSON)
case "/debug/wgtune":
h.handleWGTune(w, r)
case "/debug/runtime":
h.handleRuntime(w, r)
default:
if h.handleClientRoutes(w, r, path, wantJSON) {
return
@@ -230,10 +237,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
}
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
@@ -242,7 +249,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
resp := map[string]interface{}{
resp := map[string]any{
"version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
@@ -320,10 +327,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
sortedIDs := sortedAccountIDs(clients)
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
@@ -332,7 +339,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
@@ -418,7 +425,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
})
if wantJSON {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"account_id": accountID,
"status": overview.FullDetailSummary(),
})
@@ -501,20 +508,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
host := r.URL.Query().Get("host")
portStr := r.URL.Query().Get("port")
if host == "" || portStr == "" {
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
return
}
port, err := strconv.Atoi(portStr)
if err != nil || port < 1 || port > 65535 {
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
h.writeJSON(w, map[string]any{"error": "invalid port"})
return
}
@@ -533,7 +540,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"host": host,
"port": port,
@@ -546,7 +553,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
}
latency := time.Since(start)
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"host": host,
"port": port,
@@ -558,25 +565,25 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
level := r.URL.Query().Get("level")
if level == "" {
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
return
}
if err := client.SetLogLevel(level); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"level": level,
})
@@ -587,7 +594,7 @@ const clientActionTimeout = 30 * time.Second
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
@@ -595,14 +602,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
defer cancel()
if err := client.Start(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"message": "client started",
})
@@ -611,7 +618,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
@@ -619,19 +626,136 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
defer cancel()
if err := client.Stop(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"message": "client stopped",
})
}
func (h *Handler) handleWGTune(w http.ResponseWriter, r *http.Request) {
values, ok := r.URL.Query()["value"]
if !ok {
h.writeJSON(w, map[string]any{
"default": nbembed.WGDefaultPreallocatedBuffersPerPool(),
"batch_size": nbembed.WGDefaultMaxBatchSize(),
})
return
}
if len(values) == 0 || values[0] == "" {
http.Error(w, "value parameter must not be empty", http.StatusBadRequest)
return
}
raw := values[0]
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
return
}
nbembed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
applied := 0
failed := map[string]string{}
for accountID, client := range h.provider.ListClientsForStartup() {
capN := uint32(n)
if err := client.SetWGTuning(nbembed.WGTuning{PreallocatedBuffersPerPool: &capN}); err != nil {
failed[string(accountID)] = err.Error()
continue
}
applied++
}
resp := map[string]any{
"success": true,
"default": uint32(n),
"batch_size": nbembed.WGDefaultMaxBatchSize(),
"applied": applied,
}
if len(failed) > 0 {
resp["failed"] = failed
}
h.writeJSON(w, resp)
}
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
// running proxy; does not read pprof profiles.
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
clients := h.provider.ListClientsForDebug()
started := 0
for _, c := range clients {
if c.HasClient {
started++
}
}
resp := map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"goroutines": runtime.NumGoroutine(),
"num_cpu": runtime.NumCPU(),
"gomaxprocs": runtime.GOMAXPROCS(0),
"go_version": runtime.Version(),
"heap_alloc": m.HeapAlloc,
"heap_inuse": m.HeapInuse,
"heap_idle": m.HeapIdle,
"heap_released": m.HeapReleased,
"heap_sys": m.HeapSys,
"sys": m.Sys,
"live_objects": m.Mallocs - m.Frees,
"num_gc": m.NumGC,
"pause_total_ns": m.PauseTotalNs,
"clients": len(clients),
"started": started,
}
if proc := readProcStatus(); proc != nil {
resp["vm_rss"] = proc["VmRSS"]
resp["vm_size"] = proc["VmSize"]
resp["vm_data"] = proc["VmData"]
}
h.writeJSON(w, resp)
}
// readProcStatus parses /proc/self/status on Linux and returns size fields
// in bytes. Returns nil on non-Linux or read failure.
func readProcStatus() map[string]uint64 {
raw, err := os.ReadFile("/proc/self/status")
if err != nil {
return nil
}
out := map[string]uint64{}
for _, line := range strings.Split(string(raw), "\n") {
k, v, ok := strings.Cut(line, ":")
if !ok {
continue
}
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
continue
}
fields := strings.Fields(v)
if len(fields) < 1 {
continue
}
n, err := strconv.ParseUint(fields[0], 10, 64)
if err != nil {
continue
}
// Values are reported in kB.
out[k] = n * 1024
}
return out
}
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
if !wantJSON {
http.Redirect(w, r, "/debug", http.StatusSeeOther)
@@ -685,7 +809,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
h.writeJSON(w, resp)
}
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl := h.getTemplates()
if tmpl == nil {
@@ -698,7 +822,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
}
}
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")

View File

@@ -30,8 +30,6 @@ import (
const ConnectTimeout = 10 * time.Second
const healthCheckTimeout = 5 * time.Second
const (
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
// for the management client connection. Value is in bytes.
@@ -534,7 +532,7 @@ func (c *GrpcClient) IsHealthy() bool {
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})

View File

@@ -23,8 +23,6 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
const healthCheckTimeout = 5 * time.Second
// ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface {
MarkSignalDisconnected(error)
@@ -265,7 +263,7 @@ func (c *GrpcClient) IsHealthy() bool {
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
Key: c.key.PublicKey().String(),